报错如下:
assert all(map(lambda i: i.is_cuda, inputs))
AssertionError
报错原因:
pytorch框架代码在服务器上运行时,服务器上有好几GPU,程序暂时不知道怎么用多个gpu(或者说不知道用哪个GPU)
解决方法:
方法1:
import os
os.environ[“CUDA_VISIBLE_DEVICES”] = "0" # 指定一个GPU编号,这里例子写的0,也可以是1,2,...
方法2:
net = torch.nn.DataParallel(model,device_ids=[0])
# 通过device_ids进行指定
以上两种方法使用之后就不再报错了,但是只能使用一个GPU,多个GPU使用,并且不报错,还在研究中。。。。
参考:
https://blog.csdn.net/qq_32780465/article/details/106300394
https://blog.csdn.net/baoyongshuai1509/article/details/85254298
https://blog.csdn.net/kongkongqixi/article/details/100521590