解决pytorch模型加载时gpu id的限制
问题描述
刚开始接触pytorch时,发现每次调用训练好的模型,总是被原来训练时使用的第几个gpu限制。
假如我训练模型时,用的是第3号gpu。那么在测试模型时,加载模型时直接使用GPU的话,就会被限制使用第3号gpu才能运行。
解决方案
首先将模型加载到cpu,然后再使用GPU。
代码示例如下:
net = XXXXXX
net.load_state_dict(torch.load(XXX.pth.gz, map_location="cpu")) # 首先,将模型放到CPU
net.eval()
net.to("cuda:0") # 然后,再使用GPU