【pytorch】模型保存与加载

前言

学习pytorch框架,掌握模型保存与加载是必不可少的环节,本文记录pytorch模型保存与加载的主要工具,和各种使用场景。
 

主要工具

torch.save

torch.save方法使用python中的pickle库,保存对象到硬盘文件,注意对象不仅限于模型,还可以保存tensor、字典等其它对象。其签名如下:

torch.save(obj, f: Union[str, os.PathLike, BinaryIO], 
pickle_module=<module 'pickle' from '/opt/conda/lib/python3.6/pickle.py'>, 
pickle_protocol=2, _use_new_zipfile_serialization=True)None

一般指定需保存的对象,和保存的路径字符串就可以了,注意pytorch保存对象的文件名后缀一般为“.pt” 或者".pth"。
 

torch.load

torch.load方法使用python中的pickle库反序列化功能,将硬盘中的文件内容加载到内存,其签名如下:

torch.load(f, map_location=None, pickle_module=<module 'pickle' from '/opt/conda/lib/python3.6/pickle.py'>, 
 **pickle_load_args)

加载对象时,对象首先在“CPU”上被反序列化,然后移动到torch.save保存它们时的设备上。该过程可能会抛出错误——比如执行torch.save时的设备,在执行torch.load时该设备已不存在。出现这种情况可以使用关键词参数map_location指定对象新的寄存地点。

map_location可以是“可调用对象、torch.device、字符串或者字典”。当其为可调用对象时,格式为lambda storage location: storage或者lambda storage location: None。默认情况下,加载的对象会有两个值,“其一加载的对象表示为storage, 其二保存对象时的设备信息为location,如果是CPU型的tensor,其location值为cpu,如果为GPU型的tensor,其location值为cuda:idx”。如使用第一种方式,则会取tensor保存时的设备信息,如果报错,则自动使用第二种方式,相当于忽略指定了map_location关键字参数。

指定torch.device或者位置字符串(e.g. cpu | cuda:0)会将所有tensor移动到指定位置,指定字典(e.g. {‘cuda:0’: ‘cuda:1’}, 将cuda:0上的tensor转移到’cuda:1’上。综上所述,map_location就是用来指定加载的对象到底放到哪个设备上。

# 加载所有的tensors到CPU
torch.load('tensors.pt', map_location=torch.device('cpu'))

# 使用函数加载所有的tensor到cpu, 假设tensor保存时是位于cpu
torch.load('tensors.pt', map_location=lambda storage, loc: storage)

# 加载所有的tensor到GPU1
torch.load('tensors.pt', map_location=lambda storage, loc: storage.cuda(1))

# 将GPU1上的tensor移动到GPU0
torch.load('tensors.pt', map_location={'cuda:1':'cuda:0'})

 

nn.Module.state_dict

注意这是pytorch模型的方法,该方法返回一个字典,其键为parameter与persistent buffers的名称,字典中的值就是保存的tensor。其签名为:

state_dict(destination=None, prefix='', keep_vars=False)

不指定destination时,直接返回到调用处。
 

nn.Module.load_state_dict

该方法根据参数state_dict中的键与nn.Module自己的state_dict中的键,将state_dict参数中的tensor复制到nn.Module中。默认方式为“精确匹配”, 可以指定关键词参数“strict=False”改变默认方式。

“精确匹配”方式下,如果参数state_dict中包含nn.Module的state_dict中不存在的键,则会报出“unexpected _keys的错误”,如果参数state_dict不包含nn.Module的state_dict所需的键值对,则会报出“missing_keys”错误。

当指定“非精确匹配”时,会返回一个元组,元组中包含unexpected _keys与missing_keys的列表,用于收集不能匹配的键,非精确模式非常适合加载部分模型参数。

load_state_dict(state_dict: Dict[str, torch.Tensor], strict: bool = True)

 

各种场景下的模型保存与加载

保存与加载模型用于推断

有两种方式,推荐使用第一种方式——“以状态字典的方式保存或加载模型参数”。保存状态字典:

torch.save(model.state_dict(), PATH)

加载状态字典:

model.load_state_dict(torch.load(PATH))

第二种方式是保存与加载整个模型。保存模型:

torch.save(model, PATH)

加载模型:

torch.load(PATH)

第二种方式的优点就是非常直观,但缺点非常明显,由于是直接保存整个模型,保存的文件仅能用于特定的模型类,并且pickle在保存模型时,并不会保存模型类的定义,而是直接保存包含类定义的文件,所以加载模型时,原始保存的目录结构及模型定义都不能发生变化,否则会加载失败。但第一种方式则不会存在这种问题,因为第一种方式的本质是根据字典的键匹配来加载模型参数的。
 

保存通用检查点

通用检查点比较灵活,既可以用于推断,也可以用于继续训练模型,其保存命令如下:

torch.save({
	'epoch': epoch,
	'model_state_dict': model.state_dict(),
	'optimizer_state_dict': optimizer_state_dict(),
	'loss': loss,
	......
	}, PATH)

可以将后续可能用到的对象全部保存进字典。通用检查点的加载如下所示:

checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
loss = checkpoint['loss']
epoch = checkpoint['epoch']
......

对于通用检查点,pytorch中一般以“.tar”后缀标识文件
 

保存多个模型进一个文件

这种场景与保存通用检查点相似,其核心就是将需要保存的信息封装到字典,然后保存字典。对于集成模型,经常就需要保存多个模型到一个文件,保存实例如下:

torch.save({
	'modelA_state_dict': modelA.state_dict(),
	'modelB_state_dict': modelB.state_dict(),
	'optimizerA_state_dict': optimizerA.state_dict(),
	'optimizerB_state_dict': optimizerB.state_dict(),
	......}, PATH)

加载参数,先初始化模型与优化器,然后使用torch.load加载字典,最后使用模型的load_state_dict方法按需加载就可以了。
 

加载其它模型的部分参数来暖启动模型

加载部分参数在使用预训练模型的情况下非常普遍,核心是设置load_state_dict方法中的strict关键字为False,这样忽略unexpected_keys与missing_keys

torch.save(modelA.state_dict(), PATH)
modelB.load_state_dict(torch.load(PATH), strict=False)

 

跨设备保存与加载模型

  • 在GPU上保存,CPU上加载:加载时需要指定torch.load方法的map_location关键词参数的信息为CPU,如下:

    torch.save(model.state_dict(), PATH)
    
    device = torch.device('cpu')
    model.load_state_dict(torch.load(PATH, map_location=device))
    
  • GPU上保存,GPU上加载:加载时需要指明社保信息,并且由于模型初始化的时候是位于CPU的,因此还需要将初始化的模型移动到GPU上。另外由于模型是在GPU上,因此对于模型的输入,也需要转移到GPU上。

    torch.save(model.state_dict(), PATH)
    
    device = torch.device('cuda:0')
    model.load_state_dict(torch.load(PATH, map_location=device))
    model.to(device)
    
  • 保存torch.nn.DataParallel模型

    torch.save(model.module.state_dict(), PATH)
    # Load to whatever device you want
    

 

参考资料

torch.save

torch.load

torch.nn.Module

SAVING AND LOADING MODELS

  • 2
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值