pytorch : IndexError: scatter_(): Expected dtype int64 for index.
1.问题产生原因及解决方法
scatte_r()要求数据是int64类型,检查传入scatter()函数的tensor 类型是不是int64,假如不是进行修改
data = torch.from_numpy(np.random.randint(1, V, size=(batch, 10)))
data[:, 0] = 1
data = torch.tensor(data,dtype = torch.int64)
torch.from_numpy()产生的tensor类型为int32,因此需要用data = torch.tensor(data,dtype = torch.int64)进行数据类型的转换,然后就可以运用了
2.IndexError: invalid index of a 0-dim tensor. Use tensor.item()
in Python or tensor.item<T>()
in C++ to convert a 0-dim tensor to a number
return loss.data[0] * norm
修改为
return loss.item() * norm