在使用nn.Embedding的时候最容易出现IndexError: index out of range in self
错误,原因在于tensor中的值越界了。
官方解释中:
num_embeddings (int) – size of the dictionary of embeddings
embedding_dim (int) – the size of each embedding vector
上述num_embeddings值通常是输入的数量值,因为输入通常从0开始编码,比如编码到0-499,那么一共是500个数,因此num_embeddings=500,可以保证一定不会越界。
但是实际上,num_embeddings是1000也可以的,关键在于输入中的值不要超过num_embeddings这个值即可。因此如果只有500个数,但是输入值并不是从0-499编码,其中如果一旦有值超过了499,那么也会报错。
以及注意,输入的tensor必须是整型(long,int)
例子:
import torch.nn as nn
import torch
user_embeddings = nn.Embedding(11, 50)
print(user_embeddings.num_embeddings) # num_embeddings=11
a = torch.Tensor([10,10,10,10]).long()
print(a.min()) # 10
print(a.max()) # 10
u_embs = user_embeddings(a)
上述输入的最大值和最小值都没有超过num_embeddings,是可以的;但是如果将num_embeddings改为10,就会报错。
另一个例子:
user_embeddings = nn.Embedding(2, 50)
print(user_embeddings.num_embeddings) # num_embeddings=2
a = torch.ones([10,10,10,10]).long()
print(a.min()) # 1
print(a.max()) # 1
u_embs = user_embeddings(a)
上述可以发现,将num_embeddings改为2,输入的最大值和最小值为1,同样没有超过,也是可以的。
所以,注意num_embeddings限制的是输入量中的值,而不是输入的维度大小。
参考:
- pytorch官方nn.Embedding文档(Docs > torch.nn > Embedding):https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html?highlight=nn%20embedding#torch.nn.Embedding
- pytorch embedding层报错IndexError: index out of range in self:https://blog.csdn.net/weixin_36488653/article/details/118485063