pytorch训练出现RuntimeError: Input type (torch.cuda.HalfTensor) and weight type (torch.FloatTensor)的错误,一般是输入数据和网络数据不匹配。
解决方法:
输入数据为:out
device = out.data.device
weight = torch.ones([3,128,3,3]).to(device)
重新训练,程序正常执行。
pytorch训练出现RuntimeError: Input type (torch.cuda.HalfTensor) and weight type (torch.FloatTensor)的错误,一般是输入数据和网络数据不匹配。
解决方法:
输入数据为:out
device = out.data.device
weight = torch.ones([3,128,3,3]).to(device)
重新训练,程序正常执行。