openpcdet-0.1 pointpillars train.py

openpcdet-0.1 pointpillars train.py 注释


pointpillars
paper: link.
code: link.


def parge_config():
    parser = argparse.ArgumentParser(description='arg parser')
    parser.add_argument('--cfg_file', type=str, default=None, help='specify the config for training')

    parser.add_argument('--data_dir', type=str, default=None)
    parser.add_argument('--batch_size', type=int, default=16, required=False, help='batch size for training')
    parser.add_argument('--epochs', type=int, default=80, required=False, help='number of epochs to train for')
    parser.add_argument('--workers', type=int, default=4, help='number of workers for dataloader')  #使用DataLoader 加载数据的线程数
    parser.add_argument('--extra_tag', type=str, default='default', help='extra tag for this experiment')
    parser.add_argument('--ckpt', type=str, default=None, help='checkpoint to start from')
    parser.add_argument('--pretrained_model', type=str, default=None, help='pretrained_model')
    parser.add_argument('--launcher', choices=['none', 'pytorch', 'slurm'], default='none')
    parser.add_argument('--tcp_port', type=int, default=18888, help='tcp port for distrbuted training')
    parser.add_argument('--sync_bn', action='store_true', default=False, help='whether to use sync bn')
    parser.add_argument('--fix_random_seed', action='store_true', default=False, help='whether to use sync bn')
    parser.add_argument('--ckpt_save_interval', type=int, default=2, help='number of training epochs')
    parser.add_argument('--local_rank', type=int, default=0, help='local rank for distributed training')
    parser.add_argument('--max_ckpt_save_num', type=int, default=30, help='max number of saved checkpoint')
    parser.add_argument('--set', dest='set_cfgs', default=None, nargs=argparse.REMAINDER,
                        help='set extra config keys if needed')

    args = parser.parse_args()

    cfg_from_yaml_file(args.cfg_file, cfg)  #此函数是将输入的args.cfg参数赋值给cfg
    cfg.TAG = Path(args.cfg_file).stem
    """
    Path.
    name: 目录的最后一个部分
    suffix:目录中最后一个部分的扩展名
    stem:目录最后一个部分,没有后缀
    suffixes:返回多个扩展名列表
    with_suffix(suffix):补充扩展名到尾部,扩展名存在无效
    with_name(name):替换目录最后一个部分并返回一个新的路径
     p = Path('/tmp/test.tar.gz')
    >>> PosixPath('/tmp/test.tar.gz')
    print(p.name)
    >>>test.tar.gz
    print(p.suffix)
    >>>.gz
    print(p.suffixes)
    >>>['.tar', '.gz']
    print('p.stem')
    >>>test.tar
    print(p.with_name('test2.tgz'))
    >>>/tmp
    p = Path('/tmp/README')
    >>>PosixPath('/tmp/README')
    print(p.with_suffix('.txt'))
    >>>/tmp/README.txt
    """

    if args.set_cfgs is not None:
        cfg_from_list(args.set_cfgs, cfg)   #通过列表设置config

    return args, cfg


def main():
    args, cfg = parge_config()
    if args.launcher == 'none':
        dist_train = False
    else:
        #  分布式  init_dist_%s为0  cfg.LOCAL_RANK=0GPU排序
        args.batch_size, cfg.LOCAL_RANK = getattr(common_utils, 'init_dist_%s' % args.launcher)(
            args.batch_size, args.tcp_port, args.local_rank, backend='nccl'
        )
        dist_train = True
    if args.fix_random_seed:
        common_utils.set_random_seed(666)

    output_dir = cfg.ROOT_DIR / 'output' / cfg.TAG / args.extra_tag   # 看args 输入参数设置 extra_tag == defaut
    output_dir.mkdir(parents=True, exist_ok=True)  # 设置output_dir为父目录
    ckpt_dir = output_dir / 'ckpt'
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    #日志输出文件  年月日时分秒
    log_file = output_dir / ('log_train_%s.txt' % datetime.datetime.now().strftime('%Y%m%d-%H%M%S'))
    logger = common_utils.create_logger(log_file, rank=cfg.LOCAL_RANK)  #创建日志文件

    # log to file  开始训练的日志
    logger.info('**********************Start logging**********************')
    gpu_list = os.environ['CUDA_VISIBLE_DEVICES'] if 'CUDA_VISIBLE_DEVICES' in os.environ.keys() else 'ALL'
    logger.info('CUDA_VISIBLE_DEVICES=%s' % gpu_list)

    if dist_train:  #分布式训练
        total_gpus = dist.get_world_size()
        logger.info('total_batch_size: %d' % (total_gpus * args.batch_size))
    for key, val in vars(args).items(): # 输入参数的值  item()方法把字典中每对key和value组成一个元组,并把这些元组放在列表中返回。
        logger.info('{:16} {}'.format(key, val))
    log_config_to_file(cfg, logger=logger)

    tb_log = SummaryWriter(log_dir=str(output_dir / 'tensorboard')) if cfg.LOCAL_RANK == 0 else None

    # -----------------------create dataloader & network & optimizer---------------------------
    #  输入数据  参数: 查找数据目录  batch_size   分布训练与否?  使用DataLoader加载数据的线程数  日志文件输出  是否是训练集
    train_set, train_loader, train_sampler = build_dataloader(
        cfg.DATA_CONFIG.DATA_DIR, args.batch_size, dist_train, workers=args.workers, logger=logger, training=True
    )

    model = build_network(train_set)
    if args.sync_bn:    #  分布式训练
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model.cuda()
    # 优化损失函数
    optimizer = build_optimizer(model, cfg.MODEL.TRAIN.OPTIMIZATION)

    # load checkpoint if it is possible
    start_epoch = it = 0
    last_epoch = -1
    if args.pretrained_model is not None:
        model.load_params_from_file(filename=args.pretrained_model, to_cpu=dist, logger=logger)

    if args.ckpt is not None:  #从某次训练的地方开始
        it, start_epoch = model.load_params_with_optimizer(args.ckpt, to_cpu=dist, optimizer=optimizer, logger=logger)
        last_epoch = start_epoch + 1
    else:
        # glob.glob()返回所有匹配的文件路径列表。它只有一个参数pathname,定义了文件路径匹配规则,这里可以是绝对路径,也可以是相对路径。
        ckpt_list = glob.glob(str(ckpt_dir / '*checkpoint_epoch_*.pth'))
        if len(ckpt_list) > 0:
            ckpt_list.sort(key=os.path.getmtime)
            it, start_epoch = model.load_params_with_optimizer(
                ckpt_list[-1], to_cpu=dist, optimizer=optimizer, logger=logger
            )
            last_epoch = start_epoch + 1

    model.train()  # before wrap to DistributedDataParallel to support fixed some parameters
    if dist_train:
        model = nn.parallel.DistributedDataParallel(model, device_ids=[cfg.LOCAL_RANK % torch.cuda.device_count()])
    logger.info(model)  #模型输出
    # lr_scheduler  学习率   lr_warmup_scheduler关于warmup学习率
    lr_scheduler, lr_warmup_scheduler = build_scheduler(
        optimizer, total_iters_each_epoch=len(train_loader), total_epochs=args.epochs,
        last_epoch=last_epoch, optim_cfg=cfg.MODEL.TRAIN.OPTIMIZATION
    )

    # -----------------------start training---------------------------
    logger.info('**********************Start training %s(%s)**********************' % (cfg.TAG, args.extra_tag))
    train_model(
        model,
        optimizer,
        train_loader,
        model_func=model_fn_decorator(),
        lr_scheduler=lr_scheduler,
        optim_cfg=cfg.MODEL.TRAIN.OPTIMIZATION,
        start_epoch=start_epoch,
        total_epochs=args.epochs,
        start_iter=it,
        rank=cfg.LOCAL_RANK,
        tb_log=tb_log,
        ckpt_save_dir=ckpt_dir,
        train_sampler=train_sampler,
        lr_warmup_scheduler=lr_warmup_scheduler,
        ckpt_save_interval=args.ckpt_save_interval,
        max_ckpt_save_num=args.max_ckpt_save_num
    )

    logger.info('**********************End training**********************')

不对的地方,请留言,谢谢。

  • 3
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值