pyhson大数据处理_【pytorch】保存与加载模型(1.7.0官方教程翻译)

0 项目场景

pytorch训练完模型后,如何保存与加载?保存/加载有两种方式:一是保存/加载模型参数,二是保存/加载整个模型。

1 模型参数

保存/加载模型参数,官方推荐用这种方式,原因也给了:说这种方式对于日后恢复模型更具灵活性。

1.1 保存

torch.save(model.state_dict(), PATH)

state_dict里保存有模型的参数,PATH是保存路径,推荐.pt或.pth作为文件拓展名。

1.2 加载

model = TheModelClass(*args, **kwargs)

model.load_state_dict(torch.load(PATH))

model.eval()

TheModelClass是你定义的模型结构,PATH是保存路径。如果你的模型结构中含有dropout或batch normalization层,在测试之前一定要加上model.eval()(如果没有可以不加),不然会产生错误的输出结果。

2 整个模型

保存/加载整个模型,官方不推荐这种方式,原因也给了:说是在其它项目中使用或重构后,代码可能会中断。

2.1 保存

torch.save(model, PATH)

2.2 加载

# Model class must be defined somewhere

model = torch.load(PATH)

model.eval()

定义模型结构的类必须在代码中出现。这种保存/加载模型的方式从语法上来说更加简洁和直观,但是将模型引入其它项目中使用可能出错,所以只在自己的项目中使用应该没有问题,想将模型引入其它项目中使用还是推荐第一种保存/加载方式。

3 断点续训

顾名思义就是从上次没训练完的地方继续训练,这对高效训练来说具有重要意义。

3.1 保存

torch.save({

'epoch': epoch,

'model_state_dict': model.state_dict(),

'optimizer_state_dict': optimizer.state_dict(),

'loss': loss,

...

}, PATH)

其中model.state_dict()和optimizer.state_dict()是必须要保存的,因为这两项会随着模型的训练而更新。epoch和loss等是作为记录用的,能让你直观的了解到目前训练到第几轮了,损失是多少。PATH是保存路径,建议以.tar为文件拓展名。

3.2 加载

model = TheModelClass(*args, **kwargs)

optimizer = TheOptimizerClass(*args, **kwargs)

checkpoint = torch.load(PATH)

model.load_state_dict(checkpoint['model_state_dict'])

optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

epoch = checkpoint['epoch']

loss = checkpoint['loss']

model.train()

# - or -

model.eval()

首先初始化模型和优化器,然后加载之前保存的模型和优化器参数。接着你可以选择从上一次结束的地方继续训练或者直接测试。继续训练的话加上model.train(),测试模型的话加上model.eval(),如果模型结构中没有dropout或batch normalization层,可以不加。

4 多个模型

有时候你可能需要将多个模型保存到一个文件中,比如GAN

4.1 保存

torch.save({

'modelA_state_dict': modelA.state_dict(),

'modelB_state_dict': modelB.state_dict(),

'optimizerA_state_dict': optimizerA.state_dict(),

'optimizerB_state_dict': optimizerB.state_dict(),

...

}, PATH)

4.2 加载

modelA = TheModelAClass(*args, **kwargs)

modelB = TheModelBClass(*args, **kwargs)

optimizerA = TheOptimizerAClass(*args, **kwargs)

optimizerB = TheOptimizerBClass(*args, **kwargs)

checkpoint = torch.load(PATH)

modelA.load_state_dict(checkpoint['modelA_state_dict'])

modelB.load_state_dict(checkpoint['modelB_state_dict'])

optimizerA.load_state_dict(checkpoint['optimizerA_state_dict'])

optimizerB.load_state_dict(checkpoint['optimizerB_state_dict'])

modelA.train()

modelB.train()

# - or -

modelA.eval()

modelB.eval()

5. 迁移学习

有时候我们在训练一个新的模型B时可以用到已有的模型A的参数,比如迁移学习,这样就不用从头开始训了,模型可以很快的收敛,大大地提高了训练效率。

5.1 保存

torch.save(modelA.state_dict(), PATH)

5.2 加载

modelB = TheModelBClass(*args, **kwargs)

modelB.load_state_dict(torch.load(PATH), strict=False)

strict=False:模型A和模型B是不完全一样的,模型B训练的时候可能只需要A中一部分值,其它不要的值就丢掉,设置strict=False就是为了匹配需要的那部分值,忽略不需要的那部分值。

6 关于设备

如何在不同的设备,比如CPU或GPU上,保存与加载模型?

6.1 GPU保存 & CPU加载

模型在GPU上训练,但想把它加载到CPU上时,用这种方式

6.1.1 GPU保存

torch.save(model.state_dict(), PATH)

6.1.2 CPU加载

device = torch.device('cpu')

model = TheModelClass(*args, **kwargs)

model.load_state_dict(torch.load(PATH, map_location=device))

# - or -

model = TheModelClass(*args, **kwargs)

model.load_state_dict(torch.load(PATH))

6.2 GPU保存 & GPU加载

模型在GPU上训练,想把它加载到GPU上时,用这种方式

6.2.1 GPU保存

torch.save(model.state_dict(), PATH)

6.2.2 GPU加载

device = torch.device("cuda")

model = TheModelClass(*args, **kwargs)

model.load_state_dict(torch.load(PATH))

model.to(device)

# Make sure to call input = input.to(device) on any input tensors that you feed to the model

6.3 CPU保存 & CPU加载

模型在CPU上训练,想把它加载到CPU上时,用这种方式

6.3.1 CPU保存

torch.save(model.state_dict(), PATH)

6.3.2 CPU加载

model = TheModelClass(*args, **kwargs)

model.load_state_dict(torch.load(PATH))

6.4 CPU保存 & GPU加载

模型在CPU上训练,想把它加载到GPU上时,用这种方式

6.4.1 CPU保存

torch.save(model.state_dict(), PATH)

6.4.2 GPU加载

device = torch.device("cuda")

model = TheModelClass(*args, **kwargs)

model.load_state_dict(torch.load(PATH, map_location="cuda:0")) # Choose whatever GPU device number you want

model.to(device)

# Make sure to call input = input.to(device) on any input tensors that you feed to the model

7 引用参考

https://pytorch.org/tutorials/beginner/saving_loading_models.html

本文分享 CSDN - Xavier Jiezou。

如有侵权,请联系 support@oschina.cn 删除。

本文参与“OSC源创计划”,欢迎正在阅读的你也加入,一起分享。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值