pytorch保存模型pth_PyTorch模型的保存与加载简单总结

本文介绍了PyTorch中模型的保存与加载方法,包括torch.save()、torch.load()、Module.state_dict()和Module.load_state_dict()的使用。详细阐述了如何保存和加载整个模型、state_dict,以及如何处理不同模型间部分参数匹配的情况。此外,还提到了加载预训练模型的部分和解决CUDA内存不足的问题。
摘要由CSDN通过智能技术生成

torch.save()和torch.load():

torch.save()和torch.load()配合使用,

分别用来保存一个对象(任何对象,

不一定要是PyTorch中的对象)到文件,和从文件中加载一个对象.

加载的时候可以指明是否需要数据在CPU和GPU中相互移动.

Module.state_dict()和Module.load_state_dict():

Module.state_dict()返回一个字典,

该字典以键值对的方式保存了Module的整个状态.

Module.load_state_dict()可以从一个字典中加载参数到这个module和其后代,

如果strict是True,

那么所加载的字典和该module本身state_dict()方法返回的关键字必须严格确切的匹配上.

If strict is True,

then the keys of state_dict must exactly match the keys returned by this module’s state_dict() function.

返回值是一个命名元组:

NamedTuple with missing_keys and unexpected_keys fields,

分别保存缺失的关键字和未预料到的关键字.

如果自己的模型跟预训练模型只有部分层是相同的,

那么可以只加载这部分相同的参数,

只要设置strict参数为False来忽略那些没有匹配到的keys即可。

# 方式1:

# model_path = 'model_name.pth'

# model_params_path = 'params_name.pth'

# ----保存----

# torch.save(model, model_path)

# ----加载----

# model = torch.load(model_path)

# 方式2:

#----保存----

# torch.save(model.state_dict(), model_params_path) #保存的文件名后缀一般是.pt或.pth

#----加载----

# model=Model().cuda() #定义模型结构

# model.load_state_dict(torch.load(model_params_path)) #加载模型参数

说明:

# 保存/加载整个模型

torch.save(model, PATH)

model = torch.load(PATH)

model.eval()

这种保存/加载模型的过程使用了最直观的语法,

所用代码量少。这使用Python的pickle保存所有模块。

这种方法的缺点是,保存模型的时候,

序列化的数据被绑定到了特定的类和确切的目录。

这是因为pickle不保存模型类本身,而是保存这个类的路径,

并且在加载的时候会使用。因此,

当在其他项目里使用或者重构的时候,加载模型的时候会出错。

# 保存/加载 state_dict(推荐)

torch.save(model.state_dict(), PATH)

model = TheModelClass(*args, **kwargs)

model.load_state_dict(torch.load(PATH))

model.eval()

自己选择要保存的参数,设置checkpoint:

#----保存----

torch.save({

'epoch': epoch + 1,

'arch': args.arch,

'state_dict': model.state_dict(),

'optimizer_state_dict': optimizer.state_dict(),

'loss': loss,

'best_prec1': best_prec1,},

'checkpoint_name.tar' )

#----加载----

checkpoint = torch.load('checkpoint_name.tar')

#按关键字获取保存的参数

start_epoch = checkpoint['epoch']

best_prec1 = checkpoint['best_prec1']

state_dict=checkpoint['state_dict']

model=Model()#定义模型结构

model.load_state_dict(state_dict)

保存多个模型到同一个文件:

#----保存----

torch.save({

'modelA_state_dict': modelA.state_dict(),

'modelB_state_dict': modelB.state_dict(),

'optimizerA_state_dict': optimizerA.state_dict(),

'optimizerB_state_dict': optimizerB.state_dict(),

...

}, PATH)

#----加载----

modelA = TheModelAClass(*args, **kwargs)

modelB = TheModelAClass(*args, **kwargs)

optimizerA = TheOptimizerAClass(*args, **kwargs)

optimizerB = TheOptimizerBClass(*args, **kwargs)

checkpoint = torch.load(PATH)

modelA.load_state_dict(checkpoint['modelA_state_dict']

modelB.load_state_dict(checkpoint['modelB_state_dict']

optimizerA.load_state_dict(checkpoint['optimizerA_state_dict']

optimizerB.load_state_dict(checkpoint['optimizerB_state_dict']

modelA.eval()

modelB.eval()

# or

modelA.train()

modelB.train()

# 在这里,保存完模型后加载的时候有时会

# 遇到CUDA out of memory的问题,

# 我google到的解决方法是加上map_location=‘cpu’

checkpoint = torch.load(PATH,map_location='cpu')

加载预训练模型的部分:

resnet152 = models.resnet152(pretrained=True) #加载模型结构和参数

pretrained_dict = resnet152.state_dict()

"""加载torchvision中的预训练模型和参数后通过state_dict()方法提取参数

也可以直接从官方model_zoo下载:

pretrained_dict = model_zoo.load_url(model_urls['resnet152'])"""

model_dict = model.state_dict()

# 将pretrained_dict里不属于model_dict的键剔除掉

pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}

# 更新现有的model_dict

model_dict.update(pretrained_dict)

# 加载我们真正需要的state_dict

model.load_state_dict(model_dict)

或者写详细一点:

model_dict = model.state_dict()

state_dict = {}

for k, v in pretrained_dict.items():

if k in model_dict.keys():

# state_dict.setdefault(k, v)

state_dict[k] = v

else:

print("Missing key(s) in state_dict :{}".format(k))

model_dict.update(state_dict)

model.load_state_dict(model_dict)

本文同步分享在 博客“敲代码的小风”(CSDN)。

如有侵权,请联系 support@oschina.cn 删除。

本文参与“OSC源创计划”,欢迎正在阅读的你也加入,一起分享。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值