Pytorch训练流程

调试了很久YOLO的标准Pytorch实现,将Train代码进行了重新整理,抽出了标准的Pytorch训练框架。现在整理下来,以备后用。整理后的代码分为三个大部分,每个大部分分为三个小部分:

1、初始化(Init):训练之前先分别创建Model、Dataset&Dataloader、Optimizer;

2、轮次内部(Epoch):分别进行:Dataloader遍历训练、Save模型(间隔)、Eval模型(间隔);

3、训练(Train):其实隶属于Epoch中的Dataloader遍历,最核心的训练步骤:Forward、Backward、Optimize参数;

官方YOLO的Pytorch训练代码整理以后,再简化之后就是下面这样。

其中一些小地方需要注意,例如:在模型进行训练之前,一定要调成训练模式,评估时要调成评估模式,以固定BN层和Dropout层的参数。优化器在定义时要指定需要优化的模型参数。封装输入图像和标签时,标签不需要梯度。优化器使用之后需要清零。

其他注意事项:按照惯例,一些项目上的设定参数都是需要通过argparse传入工程的,为了项目的清晰,我把全部的工程参数设定放到了"__main__"部分,核心的训练部分做为一个独立的函数存在于文件中,这样的安排可以增加代码的可读性,方便整理。

def Quan_train(opt, logger):
    ### Init Step 1: Create Model
    model, device, start_epoch = create_model(opt)

    ### Init Step 2: Create Dataset
    dataloader, train_path, valid_path, class_names = create_dataset(opt)

    ### Init Step 3: Create Optimizer
    optimizer = torch.optim.Adam(model.parameters())

    # Epoch
    for epoch in range(start_epoch, opt.epochs):
        # Set model in train.
        model.train()

        ### Epoch Step 1: Train
        for batch_i, (_, imgs, targets) in enumerate(dataloader):
            batches_done = len(dataloader) * epoch + batch_i

            # Load input and target
            imgs = Variable(imgs.to(device))
            targets = Variable(targets.to(device), requires_grad=False)

            ### Train Step 1: Forward pass, get loss
            loss, outputs = model(imgs, targets)

            ### Train Step 2: Backward pass, get gradient
            loss.backward()

            ### Train Step 3: Optimize params
            if batches_done % opt.gradient_accumulations:  # Accumulates gradient before each step
                optimizer.step()
                optimizer.zero_grad()

        ### Epoch Step 2: Save
        if epoch % opt.checkpoint_interval == 0:
            torch.save(model.state_dict(), f"checkpoints/yolov3-tiny_quan_ckpt_%d.pth" % epoch)

        ### Epoch Step 3: Eval
        if epoch % opt.evaluation_interval == 0:
            print("\n---- Evaluating Model ----")
            # Evaluate the model on the validation set
            precision, recall, AP, f1, ap_class, IoU_total = evaluate()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Super-Params
    parser.add_argument("--epochs", type=int, default=100, help="number of epochs")
    parser.add_argument("--batch_size", type=int, default=64, help="size of each image batch")
    parser.add_argument("--img_size", type=int, default=416, help="size of each image dimension")
    parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
    parser.add_argument("--gradient_accumulations", type=int, default=2, help="number of gradient accums before step")
    # ......Other Params
    opt = parser.parse_args()

    # Set Logger
    logger = Logger("logs")

    # Set env GPU
    os.environ["CUDA_VISIBLE_DEVICES"] = str(opt.gpu)

    # Train
    Quan_train(opt, logger)

 

  • 0
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值