nn.Embedding

在使用nn.Embedding的时候最容易出现IndexError: index out of range in self错误,原因在于tensor中的值越界了。

官方解释中:

num_embeddings (int) – size of the dictionary of embeddings
embedding_dim (int) – the size of each embedding vector

上述num_embeddings值通常是输入的数量值,因为输入通常从0开始编码,比如编码到0-499,那么一共是500个数,因此num_embeddings=500,可以保证一定不会越界。

但是实际上,num_embeddings是1000也可以的,关键在于输入中的值不要超过num_embeddings这个值即可。因此如果只有500个数,但是输入值并不是从0-499编码,其中如果一旦有值超过了499,那么也会报错。

以及注意,输入的tensor必须是整型(long,int)

例子:

import torch.nn as nn
import torch

user_embeddings = nn.Embedding(11, 50)
print(user_embeddings.num_embeddings)  # num_embeddings=11

a = torch.Tensor([10,10,10,10]).long()
print(a.min())  # 10
print(a.max())  # 10

u_embs = user_embeddings(a)

上述输入的最大值和最小值都没有超过num_embeddings,是可以的;但是如果将num_embeddings改为10,就会报错。

另一个例子:

user_embeddings = nn.Embedding(2, 50)
print(user_embeddings.num_embeddings)  # num_embeddings=2

a = torch.ones([10,10,10,10]).long()
print(a.min())  # 1
print(a.max())  # 1

u_embs = user_embeddings(a)

上述可以发现,将num_embeddings改为2,输入的最大值和最小值为1,同样没有超过,也是可以的。

所以,注意num_embeddings限制的是输入量中的值,而不是输入的维度大小。

参考:

  1. pytorch官方nn.Embedding文档(Docs > torch.nn > Embedding):https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html?highlight=nn%20embedding#torch.nn.Embedding
  2. pytorch embedding层报错IndexError: index out of range in self:https://blog.csdn.net/weixin_36488653/article/details/118485063
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值