报错截图
网上有很多关于这个错误的解决方法,出错原因都不一样。我的出错原因:
def step(self, X, Y):
X = X.flatten(1)
b, _ = X.shape
deltaW = self.lr * (torch.mm(X.T,Y) - self.inhibit)/b
self.weight.data.add_(deltaW)
这里我的X,Y是输入,模型用了GPU。在输入时,X转到GPU上了,而Y忘了。所以在计算deltaW的时候,就已经出问题了。
将Y也转到GPU上就可以了。Y.cuda() 或者 Y.to(device).