核心观念,运算中的数据类型应该一致!
比如你的dataset放到GPU上面 那么它的数据就是torch.cuda.FloatTensor
如果你的模型里是含有weight可学习参数,你又没有把他放到gpu上面
他的数据类型就是torch.FloatTensor
,这就不可以进行运算。
自定义损失函数即使继承了 nn.Module
但没有添加新的参数,比如:
class loss(nn.Module):
def__init__(self):
super().__init__()