姿态估计2-05:PVNet(6D姿态估计)--源码无死角解析(1)-训练代码总览

以下链接是个人关于PVNet(6D姿态估计) 所有见解,如有错误欢迎大家指出,我会第一时间纠正。有兴趣的朋友可以加微信:17575010159 相互讨论技术。若是帮助到了你什么,一定要记得点赞!因为这是对我最大的鼓励。 文末附带 \color{blue}{文末附带} 文末附带 公众号 − \color{blue}{公众号 -} 公众号 海量资源。 \color{blue}{ 海量资源}。 海量资源

姿态估计2-00:PVNet(6D姿态估计)-目录-史上最新无死角讲解

train_net.py注释

下面是对train_net.py文件的注释,该代码十分的简单,所以注释也十分简洁:

from lib.config import cfg, args
from lib.networks import make_network
from lib.train import make_trainer, make_optimizer, make_lr_scheduler, make_recorder, set_lr_scheduler
from lib.datasets import make_data_loader
from lib.utils.net_utils import load_model, save_model, load_network
from lib.evaluators import make_evaluator
import torch.multiprocessing


def train(cfg, network):
    # 如果训练数据为City,这进行文件系统共享
    if cfg.train.dataset[:4] != 'City':
        torch.multiprocessing.set_sharing_strategy('file_system')
    # 制作训练器
    trainer = make_trainer(cfg, network)
    # 制作优化器
    optimizer = make_optimizer(cfg, network)
    # 制作学习率调整器
    scheduler = make_lr_scheduler(cfg, optimizer)
    #  用于记录信息
    recorder = make_recorder(cfg)
    # 用于评估
    evaluator = make_evaluator(cfg)

    # 进行模型加载
    begin_epoch = load_model(network, optimizer, scheduler, recorder, cfg.model_dir, resume=cfg.resume)
    # set_lr_scheduler(cfg, scheduler)

    # 创建训练以及评估数据集
    train_loader = make_data_loader(cfg, is_train=True, max_iter=cfg.ep_iter)
    val_loader = make_data_loader(cfg, is_train=False)
    # train_loader = make_data_loader(cfg, is_train=True, max_iter=100)

    # 循环进行迭代训练
    for epoch in range(begin_epoch, cfg.train.epoch):
        recorder.epoch = epoch
        # 进行一个epoch的迭代训练
        trainer.train(epoch, train_loader, optimizer, recorder)
        # 记录学习了一个epoch,并且根据预设定的参数,看是否需要对学习率进行更改
        scheduler.step()
        # 迭代到指定次数,保存好训练的
        if (epoch + 1) % cfg.save_ep == 0:
            save_model(network, optimizer, scheduler, recorder, epoch, cfg.model_dir)
        # 迭代到指定次数,进行评估训练
        if (epoch + 1) % cfg.eval_ep == 0:
            trainer.val(epoch, val_loader, evaluator, recorder)

    return network


def test(cfg, network):
    # 根据配置创建训练器
    trainer = make_trainer(cfg, network)
    # 创建数据迭代器
    val_loader = make_data_loader(cfg, is_train=False)
    # 创建评估器
    evaluator = make_evaluator(cfg)
    # 加载权重
    epoch = load_network(network, cfg.model_dir, resume=cfg.resume, epoch=cfg.test.epoch)
    # 进行评估
    trainer.val(epoch, val_loader, evaluator)


def main():
    # 根据配置参数,构建网路
    network = make_network(cfg)
    # 根据传入的参数选择测试或者训练
    if args.test:
        test(cfg, network)
    else:
        train(cfg, network)


if __name__ == "__main__":
    main()

总结

训练代码的套路基本都是差不多的,基本就是
1.解析参数
2.构建网络模型
3.加载训练测试数据集迭代器
4.迭代训练
5.模型评估保存

在这里插入图片描述

  • 3
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

江南才尽,年少无知!

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值