一个简洁、好用的Pytorch训练模板

一个简洁、好用的Pytorch训练模板

代码地址:https://github.com/KinglittleQ/Pytorch-Template

怎么使用

1) 更改template.py

替换 __init__方法中的内容,增添自己的模型、优化器、评估器等等.

class Model():

    def __init__(self, args):
        self.writer = tX.SummaryWriter(log_dir=None, comment='')
        self.train_logger = None  # not neccessary
        self.eval_logger = None  # not neccessary
        self.args = args  # not neccessary

        self.step = 0
        self.epoch = 0
        self.best_error = float('Inf')

        self.model = None
        self.optimizer = None
        self.criterion = None
        self.metric = None

        self.train_loader = None
        self.test_loader = None

        self.device = None

        self.ckpt_dir = None
        self.log_per_step = None

2) 写部分训练代码

你所需要做的只是写一个简单的for循环:

model = Model()

for epoch in range(n_epochs):
    model.train()
    if (epoch + 1) % eval_per_epoch == 0:
        model.eval()

print('Done!!!')

3) 继续训练

继续训练十分方便,只需要加载之前保存好的模型。

model = Model()
if model_path:
    model.load_state(model_path)

for i in range(n_epochs):
    model.train()
    if model.epoch % eval_per_epoch == 0:
        model.eval()

Example

  • LeNet: 训练一个LeNet对MNIST手写数字进行分类

    • 训练过程如下:

      ......
      epoch 1 step 3400   loss 0.0434
      epoch 1 step 3500   loss 0.0331
      epoch 1 step 3600   loss 0.00188
      epoch 1 step 3700   loss 0.00341
      save model at ../models\best.pth.tar
      save model at ../models\1.pth.tar
      epoch 1 error 0.0237
      epoch 2 step 3800   loss 0.0201
      epoch 2 step 3900   loss 0.00523
      epoch 2 step 4000   loss 0.0236
      ......
    • 使用tensorboard可视化输出:

      tensorboard --logdir example/LeNet/log

      1210583-20190104171222630-2091271303.png
      1210583-20190104171208132-1216698607.png

    • 继续训练

      load model from checkpoint/9.pth.tar
      epoch 10    step 33800  loss 0.000128
      epoch 10    step 33900  loss 6.64e-06
      epoch 10    step 34000  loss 0.000613
      epoch 10    step 34100  loss 2.41e-05
      ......

转载于:https://www.cnblogs.com/magic-girl/p/pytorch_template.html

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值