针对PyTorch模型保存和加载时CPU和GPU之间的转换,举两个例子:
- 保存模型时转换到CPU:
python
模型初始化在GPU上
device = torch.device(“cuda”)
model = Model().to(device)
保存前转换到CPU
model.cpu()
torch.save(model.state_dict(), ‘model.pth’)
2. 加载模型时转换到GPU:
python
加载模型参数
model.load_state_dict(torch.load(‘model.pth’))
如果要加载到GPU上
device = torch.device(“cuda”)
model.to(device)
主要原因是torch.save和torch.load默认是存储在CPU内存中的,而模型在GPU上时参数是存储在GPU内存的。
所以保存前需要调用model.cpu()将参数移动到CPU内存,xn–gpumodel-hu2m38skqbr8hs8w9r0clq3bea0669fmpj3sq.to(device)将参数加载到GPU上。
这样可以正确保存和加载模型到CPU/GPU。