loss = loss + loss1 + loss2 + loss3 + loss4 + loss5 + loss6 + loss7
RuntimeError: expected device cuda:0 and dtype Float but got device cpu and dtype Float
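To find out which operand is actually on the CPU, it can help to check the device of every term before adding them. This is a minimal sketch that assumes the losses are PyTorch tensors. The helper name `find_off_device` and the example tensors are hypothetical stand-ins, not part of the original code:

```python
import torch


def find_off_device(expected_type, **tensors):
    """Return the names of tensors whose device type ('cuda' or 'cpu') differs from expected_type."""
    return [name for name, t in tensors.items() if t.device.type != expected_type]


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Hypothetical tensors standing in for the real losses.
loss = torch.zeros((), device=device)
loss1 = torch.tensor(1.5)  # no device argument, so this one is created on the CPU
# On a GPU machine this prints ['loss1']; on a CPU-only machine it prints [].
print(find_off_device(device.type, loss=loss, loss1=loss1))
```

Comparing `device.type` instead of the full device avoids a false mismatch between `cuda` and `cuda:0`.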
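The message is explicit: PyTorch expected a tensor on device cuda, but at least one of the operands is on the CPU. The fix is to move each variable to the GPU by appending .cuda():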
loss = loss.cuda()
loss1 = loss1.cuda()
# ...and likewise for loss2 through loss7
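Alternatively, use a device-agnostic version that runs on the GPU when one is available and falls back to the CPU otherwise: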
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
loss = loss.to(device)
loss1 = loss1.to(device)
# ...repeat for loss2 through loss7
loss = loss + loss1 + loss2 + loss3 + loss4 + loss5 + loss6 + loss7
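Instead of writing `.to(device)` once per variable, you can move all the terms in a single pass. This is a minimal sketch that assumes every term is a non-empty list of scalar PyTorch tensors. The helper `total_loss` and the dummy values are hypothetical. `Tensor.to` returns the tensor itself when it is already on the target device, so terms that are already on the GPU incur no extra copy:

```python
import torch


def total_loss(terms, device):
    """Move each loss term to `device` and add them up there (terms must be non-empty)."""
    return sum(t.to(device) for t in terms)


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Hypothetical stand-ins for loss1 ... loss7, all created on the CPU.
terms = [torch.tensor(float(i)) for i in range(1, 8)]
# Creating the accumulator on the target device up front avoids the mismatch at its source.
loss = torch.zeros((), device=device)
loss = loss + total_loss(terms, device)
print(loss.item(), loss.device)
```

In real training code, the cleaner long-term fix is usually to create each loss tensor on `device` from the start (for example, by passing `device=device` to `torch.zeros`) so that no later `.to(device)` calls are needed.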
For more details, see the blog post:
https://blog.csdn.net/weixin_43786241/article/details/107239616