在我的代码中我遇到这种情况一般都是在torch.cat的时候
feats=[]
for i in range(pos_ind.size(0)):
if(pos_ind[i,0] == -1000):
# 原代码:feat=(torch.zeros(features.size(0),1))
feat=(torch.zeros(features.size(0),1)).to(self.device)#修改后
#self.device是网络前面自己设置的,具体看你自己用的gpu是什么
#可以用 device=torch.device('cuda:0')#假设你的gpu是cuda:0
feats.append(feat)
continue
feat = features[:, pos_ind[i, 0], pos_ind[i, 1]]
feat=torch.unsqueeze(feat,1)
feat=feat.to(self.device)
feats.append(feat)
feats=torch.cat(feats,dim=1)#报错的地方