在 PyTorch 的 Module
和 Tensor
使用 to()
有所区别, Module
对象只需要调用 to(device)
不用接受返回值, Tensor
对象需要接收返回值.
Module
-
定义
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-
使用
model.to(device)
-
查看
print(next(model.parameters()).device)
Tensor
-
定义
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-
使用
data = data.to(device)
-
查看
print(data.device)