模型保存(保存整个模型对象,每个 epoch 一个文件):
# Pickle the entire model object (architecture + weights) to
# <checkpoint_dir>/model_<epoch>.pth.
_model_save_path = os.path.join(checkpoint_dir, f'model_{epoch}.pth')
torch.save(model, _model_save_path)
模型加载(加载整个模型对象):
# Load a whole pickled model object saved with torch.save(model, ...).
# map_location remaps tensors that were saved on a GPU id which may not exist
# on this machine (e.g. saved on cuda:1, loaded on a single-GPU box) onto `device`.
# weights_only=False is required to unpickle a full nn.Module since PyTorch 2.6
# made weights_only=True the default (the argument needs PyTorch >= 1.13).
# Unpickling can execute arbitrary code: only load checkpoint files you trust.
# The model's class must also be importable under the same module path.
checkpoint = torch.load(checkpoint_file, map_location=device, weights_only=False)
model = checkpoint.to(device)
只加载参数(state_dict),复制到已经构建好的 model 中:
# Copy only the parameters into an already-constructed `model`.
# The file still holds a whole pickled model (saved via torch.save(model, ...)),
# so we unpickle it and take its state_dict.
# Mapping to CPU makes loading work regardless of which GPU it was saved from;
# load_state_dict copies values into `model`'s existing tensors on their device.
# weights_only=False is needed on PyTorch >= 2.6 to unpickle a full module.
# Only load trusted files.
checkpoint = torch.load(lists[epoch], map_location='cpu', weights_only=False)
model.load_state_dict(checkpoint.state_dict())
当训练时模型被 nn.DataParallel 包装(多卡训练,与当前使用的显卡配置不同)而当前加载的模型未包装时,state_dict 中每个参数名都会多出 "module." 前缀,需要新建一个去掉该前缀的 OrderedDict 再加载:
from collections import OrderedDict

# The checkpoint is a whole pickled (possibly nn.DataParallel-wrapped) model.
# Map it to CPU first so loading works even if its original GPU is absent.
# weights_only=False is needed on PyTorch >= 2.6 to unpickle a full module.
# Only load trusted files.
checkpoint = torch.load(checkpoint_file, map_location='cpu', weights_only=False)

# int(...) works whether identity.max() is a CPU/GPU tensor or numpy scalar,
# and gives nn.Linear a plain Python int for the class count.
model = InceptionResnetV1(
    pretrained='vggface2',
    classify=True,
    num_classes=int(dataset.identity.max()) + 1,
)

_prefix = 'module.'  # key prefix added by nn.DataParallel
new_state_dict = OrderedDict()
for k, v in checkpoint.state_dict().items():
    # Strip the prefix only when present; blindly slicing k[7:] would mangle
    # keys from a checkpoint that was never wrapped.
    name = k[len(_prefix):] if k.startswith(_prefix) else k
    print(f'Loading parameters of Layer: {name}')  # check if model loaded
    new_state_dict[name] = v
model.load_state_dict(new_state_dict)
model = model.to(device)