PyTorch 保存的 .pth 模型文件其实就是一个 dict 类型对象序列化后的二进制文件。
1. 保存模型
def save_checkpoint(network, fname, amp_grad_scaler, optimizer, epoch,
                    save_optimizer=True, lr_scheduler=None):
    """Serialize a training checkpoint to *fname* with torch.save.

    Args:
        network: model whose ``state_dict`` is saved (tensors moved to CPU
            first so the file can be loaded on any device).
        fname: target file path.
        amp_grad_scaler: ``torch.cuda.amp.GradScaler`` or ``None``; its state
            is included only when not ``None``.
        optimizer: optimizer whose state is saved when *save_optimizer* is
            truthy; may be ``None`` otherwise.
        epoch: epoch counter stored alongside the weights.
        save_optimizer: if falsy, ``'optimizer_state_dict'`` is stored as
            ``None``.
        lr_scheduler: optional LR scheduler; its ``state_dict`` is stored when
            it exposes one (new, backward-compatible parameter — the original
            hard-coded ``lr_scheduler = None``, making the branch dead).
    """
    state_dict = network.state_dict()
    # Move every tensor to CPU so loading does not require the saving device.
    for key in state_dict.keys():
        state_dict[key] = state_dict[key].cpu()
    # Optimizer parameters are optional (e.g. inference-only checkpoints).
    optimizer_state_dict = optimizer.state_dict() if save_optimizer else None
    lr_sched_state_dct = None
    # BUG FIX: the original split this condition across two lines without
    # parentheses, which is a SyntaxError.
    if lr_scheduler is not None and hasattr(lr_scheduler, 'state_dict'):
        lr_sched_state_dct = lr_scheduler.state_dict()
    save_this = {
        'epoch': epoch,
        'state_dict': state_dict,
        'optimizer_state_dict': optimizer_state_dict,
        'lr_scheduler_state_dict': lr_sched_state_dct,
    }
    if amp_grad_scaler is not None:
        save_this['amp_grad_scaler'] = amp_grad_scaler.state_dict()
    torch.save(save_this, fname)
# Save a checkpoint after finishing this epoch.  The last argument is the
# boolean save_optimizer flag; the original passed a bare positional None,
# which only worked because None happens to be falsy.
save_checkpoint(net, "model.pth", None, None, epoch + 1, save_optimizer=False)
2. 模型加载
def load_checkpoint(net, fname, optimizer=None, amp_grad_scaler=None, fp16=True, train=True):
    """Restore a checkpoint written by ``save_checkpoint``.

    Args:
        net: model to load weights into.
        fname: checkpoint file path.
        optimizer: optimizer to restore when *train* is True and the
            checkpoint contains optimizer state.
        amp_grad_scaler: existing GradScaler to restore; a fresh one is
            created when *fp16* is True and none is given.
        fp16: whether to restore/create the AMP grad scaler.
        train: when True, also restore the optimizer state.

    Returns:
        Tuple ``(net, optimizer, amp_grad_scaler, epoch)``.
    """
    saved_model = torch.load(fname, map_location=torch.device('cpu'))
    new_state_dict = OrderedDict()
    curr_state_dict_keys = list(net.state_dict().keys())
    # If the state dict comes from nn.DataParallel but a non-parallel model is
    # used here, the keys carry a 'module.' prefix and do not match.  Strip
    # the prefix heuristically so they do.
    for k, value in saved_model['state_dict'].items():
        key = k
        if key not in curr_state_dict_keys and key.startswith('module.'):
            key = key[len('module.'):]
        new_state_dict[key] = value
    # BUG FIX: the original nested `if fp16:` / `if fp16 and ...:`, testing
    # fp16 twice; flattened to a single level.
    if fp16:
        if amp_grad_scaler is None:
            amp_grad_scaler = GradScaler()
        if 'amp_grad_scaler' in saved_model:
            amp_grad_scaler.load_state_dict(saved_model['amp_grad_scaler'])
    net.load_state_dict(new_state_dict)
    epoch = saved_model['epoch']
    if train:
        optimizer_state_dict = saved_model['optimizer_state_dict']
        if optimizer_state_dict is not None:
            optimizer.load_state_dict(optimizer_state_dict)
    # BUG FIX: the original returned only inside `if train:`, so calling with
    # train=False silently returned None instead of the documented 4-tuple.
    return net, optimizer, amp_grad_scaler, epoch
# --- Training entry point: build model/optimizer/scaler, then restore. ---
epoch = 0
device = "cuda" if torch.cuda.is_available() else "cpu"
net = Model().to(device)
net = nn.DataParallel(net, device_ids=[0, 1])
optimizer = torch.optim.SGD(...)
amp_grad_scaler = GradScaler()
# Restore last: optimizer and GradScaler must already exist and the model
# must already be on its device before loading.
# BUG FIX: the original wrote `net, optimizer, amp_grad_scaler, epoch,
# load_checkpoint(...)` — a comma where `=` was intended — and never passed
# the optimizer/scaler it had just built.
net, optimizer, amp_grad_scaler, epoch = load_checkpoint(
    net, "model.pth", optimizer=optimizer, amp_grad_scaler=amp_grad_scaler)
while epoch < max_num_epochs:
    train_model(net)
注意: 必须先执行 model.to(device)，再建立 optimizer 和 GradScaler()，最后才调用 load_checkpoint，否则会出错。
3. pth文件内容
# --- Inspect the contents of a .pth checkpoint file. ---
# BUG FIX: the original loaded into `model_dict` but then indexed three
# different undefined names (`model`, `net`, `optimizer`); every lookup now
# goes through `model_dict`, and the momentum-buffer lookup includes the
# missing 'optimizer_state_dict' level.
import torch

model_dict = torch.load("model.pth")
print(model_dict.keys())
"""
-> dict_keys([
    'epoch',
    'state_dict',
    'optimizer_state_dict',
    'lr_scheduler_state_dict',
    'amp_grad_scaler'])
"""
print(model_dict['amp_grad_scaler'])
"""
-> dict({'scale': 524288.0,
         'growth_factor': 2.0,
         'backoff_factor': 0.5,
         'growth_interval': 2000,
         '_growth_tracker': 952})
"""
print(model_dict['lr_scheduler_state_dict'])
# -> None
print(model_dict['optimizer_state_dict'].keys())
# -> dict_keys(['state', 'param_groups'])
print(model_dict['optimizer_state_dict']['state'].keys())
# -> dict_keys([0, 1, 2, ..., 51])
print(model_dict['optimizer_state_dict']['param_groups'])
"""
-> [{'lr': 9.397949325887223e-05,
     'momentum': 0.99,
     'dampening': 0,
     'weight_decay': 3e-05,
     'nesterov': True,
     'params': [0, 1, 2, ..., 51]}]
"""
print(model_dict['optimizer_state_dict']['state'][0]['momentum_buffer'].shape)
# -> torch.Size([32, 1, 5, 5])
print(model_dict['state_dict']['module.encoder.0.weight'].shape)
# -> torch.Size([32, 1, 5, 5])