pytorch迁移学习,模型加载保存,参数量计算一揽子使用技巧

查看模型参数

for name, param in mpdel.named_parameters():
    if param.requires_grad:
        print(name, ':', param.size())

计算模型参数量

total = sum([param.nelement() for param in model.parameters()])
print("Number of parameter: %.2fM" % (total/1e6))

迁移学习

https://blog.csdn.net/u011268787/article/details/80170482?utm_medium=distribute.pc_relevant.none-task-blog-BlogCommendFromMachineLearnPai2-3.control&depth_1-utm_source=distribute.pc_relevant.none-task-blog-BlogCommendFromMachineLearnPai2-3.control

EarlyStopping的使用

from pytorchtools import EarlyStopping

early_stopping = EarlyStopping(patience=patience, verbose=True)     
for batch, data in enumerate(dataloader):
	# 计算train_loss和val_loss
	early_stopping(val_loss, model)
	if early_stopping.early_stop:
	  print("Early stopping")
	  	break

模型的保存和加载

训练过程中,有时候会由于各种原因停止训练,这时候在训练过程中就需要注意将每一个epoch的模型进行保存以便下次接着训练。

模型保存
# 标准
torch.save(model.state_dict(), 'checkpoint.pt')
# 如果是多GPU
torch.save(model.module.state_dict(), 'checkpoint.pt')
# 如果需要保存多个参数
state = {'epoch': epoch + 1, #保存的当前轮数
         'state_dict': model.state_dict(), # 训练好的参数
         'optimizer': optimizer.state_dict(), # 优化器参数,为了后续的resume
         'best_pred': best_pred # 当前最好的精度
          ,....,...}
torch.save(state, 'checkpoint.pt')
模型加载
# 标准
model.load_state_dict(torch.load('checkpoint.pt'))

# 多参数的模型加载
checkpoint = torch.load('checkpiont.pt')
model.load_state_dict(checkpoint['state_dict']) # 模型参数
optimizer.load_state_dict(checkpoint['optimizer']) # 优化参数
epoch = checkpoint['epoch'] # epoch,可以用于更新学习率等
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值