原因参考:
输入的数据类型为torch.cuda.FloatTensor,说明输入数据在GPU中
模型参数的数据类型为torch.FloatTensor,说明模型还在CPU
解决方法:
加上.to(device)
如:
net = Net(in_ch=3, out_ch=16, hid_ch=32).to(device)
原因参考:
输入的数据类型为torch.cuda.FloatTensor,说明输入数据在GPU中
模型参数的数据类型为torch.FloatTensor,说明模型还在CPU
解决方法:
加上.to(device)
如:
net = Net(in_ch=3, out_ch=16, hid_ch=32).to(device)