官方推荐使用:
# Save only the network's parameters (its state_dict): fast and takes little disk space.
torch.save(model.state_dict(),PATH)
对应的加载模型代码则为
# NOTE(review): load_state_dict() copies the loaded weights into `model` in place.
# Per the PyTorch docs its return value is a report of missing/unexpected keys,
# not a state dict, despite the name `model_dict` used here.
model_dict=model.load_state_dict(torch.load(PATH))
示例
#训练模型
# Define model
class TheModelClass(nn.Module):
    """Skeleton model used by the save/load examples.

    The ``...`` placeholders mark where a real model would define its
    layers and its forward computation; as written, ``forward`` returns
    its input unchanged.
    """

    def __init__(self):
        super(TheModelClass, self).__init__()
        ...  # placeholder: layer definitions go here

    def forward(self, x):
        """Return ``x`` after the (elided) forward computation."""
        ...  # placeholder: forward computation goes here
        return x
# Build the model that will be trained.
model = TheModelClass()
# Build the optimizer over the model's parameters.
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
# Put the model into training mode.
model.train()
# Save the model and optimizer state together in one checkpoint file,
# under the keys that load_checkpoint() reads back.
torch.save(
    {
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
    },
    PATH,
)
#测试模型
#load model
def load_checkpoint(model, checkpoint_PATH, optimizer):
    """Restore model and optimizer state from a checkpoint file.

    The checkpoint is expected to be a dict with the keys ``'state_dict'``
    (model weights) and ``'optimizer'`` (optimizer state), as written by
    the training snippet's ``torch.save`` call.

    Args:
        model: Module whose weights are overwritten in place.
        checkpoint_PATH: Path of the checkpoint file, or ``None`` to skip
            loading and return the inputs untouched.
        optimizer: Optimizer whose state is overwritten in place.

    Returns:
        The same ``(model, optimizer)`` objects that were passed in.
    """
    # The original guard tested an undefined name (`checkpoint`), which
    # raised NameError; the parameter is what must be checked.
    if checkpoint_PATH is not None:
        model_CKPT = torch.load(checkpoint_PATH)
        model.load_state_dict(model_CKPT['state_dict'])
        print('loading checkpoint!')
        optimizer.load_state_dict(model_CKPT['optimizer'])
    return model, optimizer
# Recreate the model and optimizer with the same configuration used for
# training, then restore their saved state from the checkpoint.
model = TheModelClass()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
model, optimizer = load_checkpoint(model, PATH, optimizer)
# nn.Module has no test() method; eval() is the call that switches the
# module to evaluation (inference) mode.
model.eval()
此时若出现报错“Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!”。则说明存储和加载模型使用了不同的设备。
1)GPU保存,CPU加载
加载模型代码
# The device string must be "cpu"; the original "cup" is not a valid device
# type and makes torch.device() raise.
device = torch.device("cpu")
model = TheModelClass()
# map_location remaps the tensors stored in the file onto `device` while loading.
model.load_state_dict(torch.load(PATH, map_location=device))
2)GPU保存,GPU加载
加载模型代码
# Build the model and load the saved weights; torch.load is called without
# map_location in this case.
model = TheModelClass()
model.load_state_dict(torch.load(PATH))
# Move the model onto the CUDA device afterwards.
device = torch.device("cuda")
model.to(device)
3)CPU保存,GPU加载
加载模型代码
# Build the model, then load the weights with map_location so the tensors
# stored in the file are remapped to GPU 0 while loading.
model = TheModelClass()
model.load_state_dict(
    torch.load(PATH, map_location="cuda:0")
)
# Move the model itself onto the CUDA device.
device = torch.device("cuda")
model.to(device)