Expected tensor for argument #1 ‘indices’ to have scalar type Long;but got torch.IntTensor instead
构建数据集代码
def data_gen(V, batch, nbatches):
"Generate random data for a src-tgt copy task."
for i in range(nbatches):
data = torch.from_numpy(np.random.randint(1, V, size=(batch, 10)))#单句长度为10
data[:, 0] = 1#第一列为1
src = Variable(data, requires_grad=False)
tgt = Variable(data, requires_grad=False)
yield Batch(src, tgt, 0)
训练时报错:
分析:训练的批量样本数据输入值需要是long值的Tensor数据,而不是int值的Tensor数据.
解决(亲测有效):
1.把输入数据data和目标数据data的类型值都从int转换为long
1.可以加long()
source = Variable(data, requires_grad=False).long()
target = Variable(data, requires_grad=False).long()
2.也可以加torch.LongTensor
source = Variable(torch.LongTensor(data), requires_grad=False)
target = Variable(torch.LongTensor(data), requires_grad=False)
修改后代码:
def data_gen(V, batch, nbatches):
"Generate random data for a src-tgt copy task."
for i in range(nbatches):
data = torch.from_numpy(np.random.randint(1, V, size=(batch, 10)))#单句长度为10
data[:, 0] = 1#第一列为1
src = Variable(data, requires_grad=False).long()
tgt = Variable(data, requires_grad=False).long()
yield Batch(src, tgt, 0)
运行成功!