一般使用torch.nn.Embedding(num_embeddings: int, embedding_dim: int)时只用到前两个参数
num_embeddings表示嵌入的字典个数,如果输入的的是数组,那么num_embeddings至少要比数组中最大的元素要大
否则,会出现IndexError: index out of range in self
# embedding = nn.Embedding(10, 3)
# 前一个数至少要比输入元素中最大值要大
embedding = nn.Embedding(8, 4)
inputs = torch.LongTensor([[1, 2, 4, 5], [4, 3, 2, 9]])
print(inputs.shape)
outputs = embedding(inputs)
print(outputs.shape)
在上面的代码中,输入数组的最大元素是9。
当设置nn.Embedding(10, 3) 时,能够正常运行,得到输出
torch.Size([2, 4])
torch.Size([2, 4, 3])
当设置nn.Embedding(8, 4)时,报错IndexError: index out of range in self。
第二个参数表示每一个嵌入向量的大小。
nn.Embedding的输入只能是LongTensor,大小为 (batch_size, sequence_length)
输出的大小为 (batch_size, sequence_length, embedding_dim),即输出在输入后增加了隐藏层的维度
参考:
https://yifdu.github.io/2018/12/05/Embedding%E5%B1%82/
https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html
https://discuss.pytorch.org/t/embedding-error-index-out-of-range-in-self/81550/2