GLIP 修改(二)

增加设置迭代步数和提前停止训练逻辑
  1. 修改“tools/train_net.py”脚本。

    新增命令行参数 --early_stop_iteration(对应 args.early_stop_iteration,类型为 int,默认值为 -1),并将其传入训练主函数 train()。

    修改前:

    def train(cfg, local_rank, distributed, use_tensorboard=False,):
    ……
        do_train(
            cfg,
            model,
            data_loader,
            optimizer,
            scheduler,
            checkpointer,
            device,
            checkpoint_period,
            arguments,
            data_loaders_val,
            meters
        )
    
    return model
    
    def main():
    parser = argparse.ArgumentParser(description="PyTorch Object Detection Training")
    ……
    parser.add_argument("--override_output_dir", default=None)
    
    args = parser.parse_args()
    ……
        model = train(cfg=cfg,
                      local_rank=args.local_rank,
                      distributed=args.distributed,
                      use_tensorboard=args.use_tensorboard)

    修改后:

    def train(cfg, local_rank, distributed, use_tensorboard=False, early_stop_iteration=-1):
    ……
        do_train(
            cfg,
            model,
            data_loader,
            optimizer,
            scheduler,
            checkpointer,
            device,
            checkpoint_period,
            arguments,
            data_loaders_val,
            meters,
            early_stop_iteration=early_stop_iteration,
        )
    
    return model
    
    def main():
    parser = argparse.ArgumentParser(description="PyTorch Object Detection Training")
    ……
    parser.add_argument("--override_output_dir", default=None)
    parser.add_argument("--early_stop_iteration", type=int, default=-1)
    
    args = parser.parse_args()
    ……
    model = train(cfg=cfg,
                  local_rank=args.local_rank,
                  distributed=args.distributed,
                  use_tensorboard=args.use_tensorboard,
                  early_stop_iteration=args.early_stop_iteration)
  2. 修改“maskrcnn_benchmark/engine/trainer.py”脚本。

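    为训练函数 do_train() 增加 early_stop_iteration 参数。当该值大于 0 时,训练执行到第 early_stop_iteration 步后停止(在第 early_stop_iteration + 1 步时跳出训练循环);若提供了 val_data_loader,还会在第 early_stop_iteration 步额外触发一次评估。默认值 -1 表示不启用提前停止。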

    修改前:

    def do_train(
            cfg,
            model,
            data_loader,
            optimizer,
            scheduler,
            checkpointer,
            device,
            checkpoint_period,
            arguments,
            val_data_loader=None,
            meters=None,
            zero_shot=False
    ):
    ……
    for iteration, (images, targets, idxs, positive_map, positive_map_eval, greenlight_map) in enumerate(data_loader, start_iter):
        ……
            arguments["iteration"] = iteration
    
    images = images.to(device)
    ……
            if val_data_loader and (iteration % checkpoint_period == 0 or iteration == max_iter):
                if is_main_process():
                    print("Evaluating")
                ……

    修改后:

    def do_train(
            cfg,
            model,
            data_loader,
            optimizer,
            scheduler,
            checkpointer,
            device,
            checkpoint_period,
            arguments,
            val_data_loader=None,
            meters=None,
            zero_shot=False,
            early_stop_iteration=-1,
    ):
    ……
    for iteration, (images, targets, idxs, positive_map, positive_map_eval, greenlight_map) in enumerate(data_loader, start_iter):
        ……
            arguments["iteration"] = iteration
            if early_stop_iteration > 0:
                if iteration == early_stop_iteration + 1:
                    break
    
    images = images.to(device)
    ……
    
            if val_data_loader and (iteration % checkpoint_period == 0 or iteration == max_iter or
                                    iteration == early_stop_iteration):
    if is_main_process():
                    print("Evaluating")
                ……
增加训练时计算实时fps的逻辑

修改“maskrcnn_benchmark/engine/trainer.py”脚本。

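增加训练时计算实时 FPS 的逻辑:每次迭代用 cfg.SOLVER.IMS_PER_BATCH 除以该次迭代耗时 batch_time 得到 train_fps,并通过 meters.update() 与 time、data 一起记录。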

修改前:

def do_train(
        cfg,
        model,
        data_loader,
        optimizer,
        scheduler,
        checkpointer,
        device,
        checkpoint_period,
        arguments,
        val_data_loader=None,
        meters=None,
        zero_shot=False,
        early_stop_iteration=-1,
):
……
for iteration, (images, targets, idxs, positive_map, positive_map_eval, greenlight_map) in enumerate(data_loader, start_iter):
    ……
    meters.update(time=batch_time, data=data_time)

修改后:

def do_train(
        cfg,
        model,
        data_loader,
        optimizer,
        scheduler,
        checkpointer,
        device,
        checkpoint_period,
        arguments,
        val_data_loader=None,
        meters=None,
        zero_shot=False,
        early_stop_iteration=-1,
):
……
for iteration, (images, targets, idxs, positive_map, positive_map_eval, greenlight_map) in enumerate(data_loader, start_iter):
    ……
train_fps = cfg.SOLVER.IMS_PER_BATCH / batch_time
meters.update(time=batch_time, data=data_time, fps=train_fps)

原文链接:概述-模型开发-Ascend Extension for PyTorch 6.0.RC2 开发文档-昇腾社区 (hiascend.com)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值