PyTorch 模型的保存、加载以及 .pth 文件的内容

PyTorch 用 torch.save 保存的 .pth 文件本质上是一个序列化（pickle）的二进制文件；本文保存的对象是一个 dict，因此加载后得到的也是一个 dict。

1. 保存模型

def save_checkpoint(network, fname, amp_grad_scaler, optimizer, epoch, save_optimizer=True,
                    lr_scheduler=None):
    """Save a training checkpoint to ``fname`` with ``torch.save``.

    The file holds a plain dict with the keys ``'epoch'``, ``'state_dict'``,
    ``'optimizer_state_dict'``, ``'lr_scheduler_state_dict'`` and, only when a
    scaler is given, ``'amp_grad_scaler'``.

    Args:
        network: module whose ``state_dict()`` is saved; every tensor is copied
            to the CPU before saving.
        fname: destination path.
        amp_grad_scaler: optional AMP ``GradScaler``; its state is stored when
            it is not None.
        optimizer: optimizer whose state is stored when ``save_optimizer`` is
            truthy. May be None, in which case no optimizer state is stored.
        epoch: value stored under ``'epoch'``.
        save_optimizer: whether to store the optimizer state.
        lr_scheduler: optional learning-rate scheduler; its state is stored when
            it exposes ``state_dict()``.
    """
    # Modify the returned state dict in place rather than building a new dict:
    # this keeps its OrderedDict type and the ``_metadata`` attribute that
    # ``load_state_dict`` uses. Reassigning existing keys does not change the
    # dict's size, so doing it while iterating is safe.
    state_dict = network.state_dict()
    for key in state_dict:
        state_dict[key] = state_dict[key].cpu()

    if save_optimizer and optimizer is not None:
        optimizer_state_dict = optimizer.state_dict()
    else:
        optimizer_state_dict = None

    lr_sched_state_dct = None
    if lr_scheduler is not None and hasattr(lr_scheduler, 'state_dict'):
        lr_sched_state_dct = lr_scheduler.state_dict()

    save_this = {
        'epoch': epoch,
        'state_dict': state_dict,
        'optimizer_state_dict': optimizer_state_dict,
        'lr_scheduler_state_dict': lr_sched_state_dct,
    }
    if amp_grad_scaler is not None:
        save_this['amp_grad_scaler'] = amp_grad_scaler.state_dict()

    torch.save(save_this, fname)

# Save the weights only: no AMP scaler and no optimizer state.
save_checkpoint(
    net,
    "model.pth",
    amp_grad_scaler=None,
    optimizer=None,
    epoch=epoch + 1,
    save_optimizer=None,
)

2. 模型加载

def load_checkpoint(net, fname, optimizer=None, amp_grad_scaler=None, fp16=True, train=True):
    """Restore a checkpoint written by ``save_checkpoint`` into ``net``.

    Args:
        net: module whose weights are loaded in place. It may be a plain module
            or a wrapper whose keys carry a ``'module.'`` prefix (as
            ``nn.DataParallel`` does); keys are adapted in either direction.
        fname: path of the checkpoint file.
        optimizer: optional optimizer. Its state is restored when ``train`` is
            true and the checkpoint contains optimizer state. Create it after
            the model has been moved to its device.
        amp_grad_scaler: optional AMP ``GradScaler``. A new one is created when
            it is None and ``fp16`` is true.
        fp16: whether to set up and restore the AMP grad scaler.
        train: whether to restore the optimizer state.

    Returns:
        ``(net, optimizer, amp_grad_scaler, epoch)``.
    """
    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    saved_model = torch.load(fname, map_location=torch.device('cpu'))

    # A model wrapped in nn.DataParallel prefixes every key with 'module.'.
    # Strip or add the prefix so the checkpoint matches the current model
    # whichever side was wrapped.
    prefix = 'module.'
    curr_state_dict_keys = set(net.state_dict().keys())
    new_state_dict = OrderedDict()
    for key, value in saved_model['state_dict'].items():
        if key not in curr_state_dict_keys:
            if key.startswith(prefix):
                key = key[len(prefix):]
            elif prefix + key in curr_state_dict_keys:
                key = prefix + key
        new_state_dict[key] = value

    if fp16:
        if amp_grad_scaler is None:
            amp_grad_scaler = GradScaler()
        if 'amp_grad_scaler' in saved_model:
            amp_grad_scaler.load_state_dict(saved_model['amp_grad_scaler'])

    net.load_state_dict(new_state_dict)
    epoch = saved_model['epoch']

    # Without an optimizer there is nothing to restore the saved state into.
    if train and optimizer is not None:
        optimizer_state_dict = saved_model.get('optimizer_state_dict')
        if optimizer_state_dict is not None:
            optimizer.load_state_dict(optimizer_state_dict)

    return net, optimizer, amp_grad_scaler, epoch

epoch = 0
device = "cuda" if torch.cuda.is_available() else "cpu"
net = Model().to(device)
# NOTE(review): device_ids=[0, 1] assumes at least two visible GPUs.
net = nn.DataParallel(net, device_ids=[0, 1])

# The optimizer and GradScaler must be created after the model is on its device,
# and load_checkpoint must be called last.
optimizer = torch.optim.SGD(...)  # fill in parameters / hyper-parameters
amp_grad_scaler = GradScaler()
# Unpack the returned tuple and pass the optimizer/scaler so their saved state
# is restored as well.
net, optimizer, amp_grad_scaler, epoch = load_checkpoint(
    net, "model.pth", optimizer=optimizer, amp_grad_scaler=amp_grad_scaler)
while epoch < max_num_epochs:
    train_model(net)
    epoch += 1

注意：创建 optimizer 和 GradScaler() 必须在 model.to(device) 之后，最后再调用 load_checkpoint，否则会出错。

3. pth文件内容

import torch

# map_location="cpu" lets the file be inspected on any machine, even if parts of
# it (e.g. optimizer state) were saved from GPU tensors.
model_dict = torch.load("model.pth", map_location="cpu")
print(model_dict.keys())
# -> dict_keys([
#        'epoch',
#        'state_dict',
#        'optimizer_state_dict',
#        'lr_scheduler_state_dict',
#        'amp_grad_scaler'])

print(model_dict['amp_grad_scaler'])
# -> {'scale': 524288.0,
#     'growth_factor': 2.0,
#     'backoff_factor': 0.5,
#     'growth_interval': 2000,
#     '_growth_tracker': 952}

print(model_dict['lr_scheduler_state_dict'])
# -> None

optimizer_state = model_dict['optimizer_state_dict']
print(optimizer_state.keys())
# -> dict_keys(['state', 'param_groups'])

# 'state' is keyed by parameter index, matching the indices in param_groups['params'].
print(optimizer_state['state'].keys())
# -> dict_keys([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51])

print(optimizer_state['param_groups'])
# -> [{'lr': 9.397949325887223e-05,
#      'momentum': 0.99,
#      'dampening': 0,
#      'weight_decay': 3e-05,
#      'nesterov': True,
#      'params': [0, 1, 2, ..., 51]
#      }]

print(optimizer_state['state'][0]['momentum_buffer'].shape)
# -> torch.Size([32, 1, 5, 5])

# The model weights carry the 'module.' prefix because the model was wrapped in
# nn.DataParallel when saved.
print(model_dict['state_dict']['module.encoder.0.weight'].shape)
# -> torch.Size([32, 1, 5, 5])

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值