更改代码使其能够在colab训练时自动保存,自动读取。

其主要改动有:

  1. 在config里增加了两个参数,一个是参数epoch_save_train_inf_for_colab,其作用是每隔多少个epoch保存一次信息;另一个参数是save_model_path_colab就是保存信息的文件夹信息
  2. 增加读取模型信息的代码
  3. 增加保存模型信息的代码

1.增加参数信息

在训练脚本的代码中,增加了读取这两个参数的代码。(在后面会将整个训练脚本的代码放出)

    epoch_save_train_inf_for_colab = config['epoch_save_train_inf_for_colab']  # 每隔多少次epoch 从colab上保存训练信息
    save_model_path_colab = confin['save_model_path_colab']

2. 增加读取模型信息代码

在这里使用的一个判断就是判断是否模型信息的文件存在,若存在则读取,若不存在则赋予相关变量为初始值。

    # 载入 之前colab训练好的参数,权重
    if os.path.exists(save_model_path_colab+'/crnn-model'+".t7"):
        save_path=save_model_path_colab+'/crnn-model'+".t7"
        checkpoint = torch.load(save_path)
        crnn.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch']

        tot_train_count_start  = checkpoint['tot_train_count']
        tot_train_loss_start = checkpoint['tot_train_loss']
        num_i_start = checkpoint['tot_train_i']


        print("当前epoch:{} 当前总的训练损失:{:.4f} 当前训练了多少次:{:}".format(tot_train_count_start+1,tot_train_loss_start,num_i_start))

    else:

        start_epoch = 1
        tot_train_count_start  = 0
        tot_train_loss_start = 0.
        num_i_start = 1
        print("当前epoch:{} 当前总的训练损失:{:.4f} 当前训练了多少次:{:}".format(tot_train_count_start+1,tot_train_loss_start,num_i_start))

3. 增加保存模型信息代码

在epoch的循环中,

 if epoch % epoch_save_train_inf_for_colab == 0:

            # 保存训练进度
            state = {
    
                        'model': crnn.state_dict(), 
                        'optimizer':optimizer.state_dict(), 
                        'epoch': epoch,
                        #    'train_loss':epoch_loss,
                        #    'train_acc':epoch_acc,
                        'tot_train_loss':tot_train_loss,
                        'tot_train_count':tot_train_count,
                        'tot_train_i':i
                        }
            # save_model_path_colab="/content/drive/My Drive/colab notebooks/output/"   
            # save_model_path_colab_new=os.path.join(save_model_path_colab,
            #                                         f'{epoch}_{i:06}_loss{loss}.pt')
            torch.save(state,save_model_path_colab+'/crnn-model'+".t7")
            print('save train model at ', save_model_path_colab+'/crnn-model'+".t7","from colab")

附上整体训练代码

import os

  • 0
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值