深度学习如何恢复训练?中断的训练如何接着之前保存的 ckpt 参数继续训练?Pytorch-Lightning Trainer

在做的实验基础代码是用的 Pytorch-Lightning 中的训练器 Trainer 进行训练

  1. 首先需要保存的训练后的模型参数,保存 checkpoint 断点
checkpoint_callback = pl.callbacks.ModelCheckpoint(
        dirpath=args.ckpt_dir + "/" + args.model_type,
        filename=model_savename + "---{epoch}---" + dt_string +'-'+str(args.use_img)+str(args.use_att)+str(args.use_date)+str(args.use_trends)+'RNN3_5',#str(note)
        monitor="val_mae",
        mode="min",#这里实验效果是越小越好,所以是“min”
        save_top_k=5,#1
    )
    
print(checkpoint_callback.best_model_path)#打印出效果最好的模型参数存储的路径

这里保存了效果前五的模型,这里的实验效果是越小越好,并打印出效果最好的模型参数存储路径

  1. 在训练器 Trainer 里加载之前保存的最佳模型
 trainer.fit(model, train_dataloaders=trainloader, val_dataloaders=testloader,ckpt_path='自己替换成最佳模型参数所存在的路径.ckpt')

主要是 trainer.fit() 函数里,ckpt_path 参数所提供的效果,输入 ckpt 文件路径(从这里文件恢复训练)

参考博客:https://blog.csdn.net/qq_27135095/article/details/122635743?ops_request_misc=%257B%2522request%255Fid%2522%253A%2522167583461916800180668936%2522%252C%2522scm%2522%253A%252220140713.130102334.pc%255Fall.%2522%257D&request_id=167583461916800180668936&biz_id=0&utm_medium=distribute.pc_search_result.none-task-blog-2allfirst_rank_ecpm_v1~rank_v31_ecpm-1-122635743-null-null.142%5Ev73%5Econtrol,201%5Ev4%5Eadd_ask,239%5Ev1%5Einsert_chatgpt&utm_term=pl%20trainer%20%E6%98%AF%E5%A6%82%E4%BD%95%E8%AE%AD%E7%BB%83%E7%9A%84&spm=1018.2226.3001.4187

  • 3
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值