修改DataLoader的generator(问题解决)
data_loader = data.DataLoader()的参数最后加了generator=torch.Generator(device = 'cuda')
修改将torch.set_default_tensor_type
将torch.set_default_tensor_type(‘torch.FloatTensor’)改为:torch.set_default_tensor_type(‘torch.cuda.FloatTensor’)。
参考:
RuntimeError: Expected a ‘cuda‘ device type for generator but found ‘cpu‘_alex1801的博客-CSDN博客