PyTorch 7.保存和加载pytorch模型的方法

保存和加载模型

python的对象都可以通过torch.save和torch.load函数进行保存和加载

x1 = {"d":"df","dd":"ddf"}
torch.save(x1,'a1.pt')
x2 = torch.load('a1.pt')

下面来谈模型的state_dict(),该函数返回模型的所有参数

class MLP(nn.Module):
	def __init__(self):
		super(MLP,self).__init__()
		self.hidden = nn.Linear(3,2)
		self.act = nn.ReLU()
		self.output = nn.Linear(2,1)
	def forward(self,x):
		a = self.act(self.hidden(x))
		return self.output(a)
net = MLP()
net.state_dict()

输出

OrderedDict([('hidden.weight',
              tensor([[-0.4195,  0.2609,  0.4325],
                      [-0.4031,  0.2078,  0.2077]])),
             ('hidden.bias', tensor([ 0.0755, -0.1408])),
             ('output.weight', tensor([[0.2473, 0.6614]])),
             ('output.bias', tensor([0.6191]))])

torch.save&torch.load

  1. 保存整个模型
    如果选择保存模型,那么可以不需要预先创建模型的实例,可以直接加载模型及其参数
torch.save(net,path)
  1. 保存模型参数
    选择保存模型参数,在加载时需要先创建模型实例
state_dict = net.state_dict()
torch.save(state_dict,path)
  1. 模型finetune
    如果模型训练中不小心中断了,或者需要用该模型去其他模型进行finetune。我们不仅要保存模型参数,还需要保存模型的训练周期及优化器参数。
    这里,我们经常会看到一个叫checkpoint的东东,它其实是一个字典
checkpoint = {
	"model_state_dict":net.state_dict(),
	"optimizer_state_dict":optimizer.state_dict(),
	"epoch":epoch
}
# 保存
torch.save(checkpoint, path_checkpoint)
# 加载
checkpoint = torch.load(path_checkpoint)
net.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch']
scheduler.last_epoch = start_epoch

跨设备保存加载模型

  1. 在CPU上加载在GPU上训练并保存的模型:
device = torch.device('cpu')
model = MyModel()
model.load_state_dict(torch.load('net_params.pth', map_location=device))
  1. 在GPU上加载在GPU上训练并保存的模型:
device = torch.device('cuda')
model = MyModel()
model.load_state_dict(torch.load('net_params.pth'))
model.to(device)

在这里使用map_location参数不起作用,要使用model.to(torch.device(“cuda”))将模型转换为CUDA优化的模型
数据也要转换到GPU
由于my_tensor.to(device)会返回一个my_tensor在GPU上的副本,它不会覆盖my_tensor

my_tensor = my_tensor.to(device)

存在多个GPU设备时
map_location指定tensor加载的GPU序号

model.load_state_dict(torch.load('net_params.pth'),map_location='cuda:0')

多GPU训练,单GPU加载

def load_model(model, model_path, optimizer=None,resume=False,
lr=None, lr_step=None):
	start_epoch = 0
	checkpoint = torch.load(model_path, map_location=lambda storage, loc: storage)
	print('loaded {}, epoch {}'.format(model_path, checkpoint['epoch']))
	state_dict_ = checkpoint['state_dict']
	state_dict = {}
	# 将data_parallal转换到model
	for k in state_dict_:
		if k.startswith('module') and not k.startswith('module_list'):
			state_dict[k[7:]] = state_dict_[k]
		else:
			state_dict[k] = state_dict_[k]
	model_state_dict = model.state_dict()
# 加载模型,检查参数,创建模型参数
	msg = 'If you see this, your model does not fully load the pre-trained weight.'
	for k in state_dict:
		if k in model_state_dict:
			if state_dict[k].shape != model_state_dict[k].shape:
				print('Skip loading parameter {}, required shape{}, loaded shape{}.{}'.format(k,model_state_dict[k].shape,state_dict[k].shape,msg))
				state_dict[k] = model_state_dict[k]
		else:
			#参数在预加载模型中存在,在本地模型不存在,丢弃参数
			print('Drop parameter {}.'.format(k)+msg)
	for k in model_state_dict:
		if not (k in state_dict):
			# 本地模型需要的参数,预加载模型中不存在,则根据本地模型参数加载
			print('No param {}.'.format(k)+msg)
			state_dict[k] = model_state_dict[k]
	model.load_state_dict(state_dict,strict = False)

	# 断点继续训练
	if optimizer is not None and resume:
		if 'optimizer' in checkpoint:
			optimizer.load_state_dict(checkpoint['optimizer'])
			start_epoch = checkpoint['epoch']
			start_lr = lr
			for step in lr_step:
				if start_epoch >= step:
					start_lr *= 0.1
			for param_group in optimizer.param_groups:
				param_group['lr'] = start_lr
			print('Resumed optimizer with start lr', start_lr)
		else:
			print('No optimizer parameters in checkpoint.')
	if optimizer is not None:
		return model, optimizer, start_epoch
	else:
		return model	
	
  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值