openpcdet-0.1 pointpillars train.py annotations
pointpillars
paper: link.
code: link.
def parse_config():
parser = argparse.ArgumentParser(description='arg parser')
parser.add_argument('--cfg_file', type=str, default=None, help='specify the config for training')
parser.add_argument('--data_dir', type=str, default=None)
parser.add_argument('--batch_size', type=int, default=16, required=False, help='batch size for training')
parser.add_argument('--epochs', type=int, default=80, required=False, help='number of epochs to train for')
parser.add_argument('--workers', type=int, default=4, help='number of workers for dataloader') # number of worker threads used by the DataLoader
parser.add_argument('--extra_tag', type=str, default='default', help='extra tag for this experiment')
parser.add_argument('--ckpt', type=str, default=None, help='checkpoint to start from')
parser.add_argument('--pretrained_model', type=str, default=None, help='pretrained_model')
parser.add_argument('--launcher', choices=['none', 'pytorch', 'slurm'], default='none')
parser.add_argument('--tcp_port', type=int, default=18888, help='tcp port for distrbuted training')
parser.add_argument('--sync_bn', action='store_true', default=False, help='whether to use sync bn')
parser.add_argument('--fix_random_seed', action='store_true', default=False, help='whether to fix the random seed')
parser.add_argument('--ckpt_save_interval', type=int, default=2, help='save a checkpoint every this many epochs')
parser.add_argument('--local_rank', type=int, default=0, help='local rank for distributed training')
parser.add_argument('--max_ckpt_save_num', type=int, default=30, help='max number of saved checkpoint')
parser.add_argument('--set', dest='set_cfgs', default=None, nargs=argparse.REMAINDER,
help='set extra config keys if needed')
args = parser.parse_args()
cfg_from_yaml_file(args.cfg_file, cfg) # load the YAML file given by args.cfg_file and merge its values into the global cfg
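# cfg_from_yaml_file is a repo helper; below is a minimal sketch of its assumed behaviour
# (illustrative names, not the actual implementation): parse the YAML file and merge the
# entries into the global EasyDict config.
"""
import yaml
from easydict import EasyDict

def load_yaml_into_cfg(cfg_file, cfg):
    with open(cfg_file, 'r') as f:
        yaml_cfg = EasyDict(yaml.safe_load(f))   # parse YAML into a nested dict
    cfg.update(yaml_cfg)                         # shallow merge into the global config
    return cfg
"""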
cfg.TAG = Path(args.cfg_file).stem
"""
Path.
name: 目录的最后一个部分
suffix:目录中最后一个部分的扩展名
stem:目录最后一个部分,没有后缀
suffixes:返回多个扩展名列表
with_suffix(suffix):补充扩展名到尾部,扩展名存在无效
with_name(name):替换目录最后一个部分并返回一个新的路径
p = Path('/tmp/test.tar.gz')
>>> PosixPath('/tmp/test.tar.gz')
print(p.name)
>>>test.tar.gz
print(p.suffix)
>>>.gz
print(p.suffixes)
>>>['.tar', '.gz']
print('p.stem')
>>>test.tar
print(p.with_name('test2.tgz'))
>>>/tmp
p = Path('/tmp/README')
>>>PosixPath('/tmp/README')
print(p.with_suffix('.txt'))
>>>/tmp/README.txt
"""
if args.set_cfgs is not None:
cfg_from_list(args.set_cfgs, cfg) # override individual config entries from the command line
return args, cfg
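# Example (illustrative) of how --set is used: argparse.REMAINDER collects every token
# after --set into args.set_cfgs, and cfg_from_list then writes those key/value pairs
# into the config. The config path and keys below are examples, not taken from the repo docs.
"""
python train.py --cfg_file cfgs/pointpillar.yaml --batch_size 4 \
    --set MODEL.TRAIN.OPTIMIZATION.LR 0.003 DATA_CONFIG.NUM_WORKERS 8

# inside parse_config():
# args.set_cfgs == ['MODEL.TRAIN.OPTIMIZATION.LR', '0.003', 'DATA_CONFIG.NUM_WORKERS', '8']
"""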
def main():
args, cfg = parse_config()
if args.launcher == 'none':
dist_train = False
else:
# distributed training: init_dist_pytorch / init_dist_slurm initializes the process group and returns the per-GPU batch size and this process's local rank (stored in cfg.LOCAL_RANK)
args.batch_size, cfg.LOCAL_RANK = getattr(common_utils, 'init_dist_%s' % args.launcher)(
args.batch_size, args.tcp_port, args.local_rank, backend='nccl'
)
dist_train = True
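# init_dist_pytorch / init_dist_slurm are repo helpers; below is a rough sketch of the pytorch
# variant under the stated assumptions (not the actual implementation): initialize the NCCL
# process group and split the total batch size across the visible GPUs.
"""
import torch
import torch.distributed as dist

def init_dist_pytorch_sketch(batch_size, tcp_port, local_rank, backend='nccl'):
    num_gpus = torch.cuda.device_count()
    torch.cuda.set_device(local_rank % num_gpus)
    dist.init_process_group(backend=backend,
                            init_method='tcp://127.0.0.1:%d' % tcp_port,
                            rank=local_rank, world_size=num_gpus)
    assert batch_size % num_gpus == 0, 'batch_size must be divisible by the number of GPUs'
    return batch_size // num_gpus, local_rank   # per-GPU batch size, local rank
"""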
if args.fix_random_seed:
common_utils.set_random_seed(666)
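# common_utils.set_random_seed is not shown in this file; a typical implementation
# (a sketch, not the repo's actual code) seeds Python, NumPy and PyTorch together:
"""
import random
import numpy as np
import torch

def set_random_seed(seed=666):
    random.seed(seed)                   # Python built-in RNG
    np.random.seed(seed)                # NumPy RNG
    torch.manual_seed(seed)             # CPU RNG
    torch.cuda.manual_seed_all(seed)    # RNG on every GPU
"""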
output_dir = cfg.ROOT_DIR / 'output' / cfg.TAG / args.extra_tag # extra_tag defaults to 'default' (see the argparse options above)
output_dir.mkdir(parents=True, exist_ok=True) # create the output directory, including any missing parent directories
ckpt_dir = output_dir / 'ckpt'
ckpt_dir.mkdir(parents=True, exist_ok=True)
# log file named with the current timestamp (year-month-day-hour-minute-second)
log_file = output_dir / ('log_train_%s.txt' % datetime.datetime.now().strftime('%Y%m%d-%H%M%S'))
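# the strftime pattern stamps the file name with the run time, so each run gets its own log;
# a small doctest-style illustration:
"""
>>> import datetime
>>> datetime.datetime(2020, 5, 1, 15, 30, 42).strftime('%Y%m%d-%H%M%S')
'20200501-153042'
"""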
logger = common_utils.create_logger(log_file, rank=cfg.LOCAL_RANK) # create the logger
# log to file: record the training setup
logger.info('**********************Start logging**********************')
gpu_list = os.environ['CUDA_VISIBLE_DEVICES'] if 'CUDA_VISIBLE_DEVICES' in os.environ.keys() else 'ALL'
logger.info('CUDA_VISIBLE_DEVICES=%s' % gpu_list)
if dist_train: # distributed training
total_gpus = dist.get_world_size()
logger.info('total_batch_size: %d' % (total_gpus * args.batch_size))
for key, val in vars(args).items(): # log every input argument; vars(args) turns the Namespace into a dict and items() yields its (key, value) pairs
logger.info('{:16} {}'.format(key, val))
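# a small illustration of vars(args).items() on an argparse Namespace:
"""
>>> import argparse
>>> args = argparse.Namespace(batch_size=16, epochs=80)
>>> vars(args)
{'batch_size': 16, 'epochs': 80}
>>> list(vars(args).items())
[('batch_size', 16), ('epochs', 80)]
"""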
log_config_to_file(cfg, logger=logger)
tb_log = SummaryWriter(log_dir=str(output_dir / 'tensorboard')) if cfg.LOCAL_RANK == 0 else None
# -----------------------create dataloader & network & optimizer---------------------------
# build the dataset and dataloader: data directory, batch size, whether training is distributed, number of DataLoader worker threads, the logger, and the training flag
train_set, train_loader, train_sampler = build_dataloader(
cfg.DATA_CONFIG.DATA_DIR, args.batch_size, dist_train, workers=args.workers, logger=logger, training=True
)
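# build_dataloader is defined elsewhere in the repo; below is a simplified sketch of the usual
# pattern it follows (dataset -> optional DistributedSampler -> DataLoader), with illustrative
# names rather than the actual implementation:
"""
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

def build_dataloader_sketch(dataset, batch_size, dist, workers=4, training=True):
    sampler = DistributedSampler(dataset, shuffle=training) if dist else None
    loader = DataLoader(
        dataset, batch_size=batch_size, num_workers=workers, pin_memory=True,
        shuffle=(sampler is None and training),   # sampler and shuffle are mutually exclusive
        sampler=sampler,
        collate_fn=getattr(dataset, 'collate_batch', None)  # point-cloud batches need a custom collate
    )
    return dataset, loader, sampler
"""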
model = build_network(train_set)
if args.sync_bn: # for distributed training, replace BatchNorm layers with SyncBatchNorm
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
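# convert_sync_batchnorm recursively replaces every BatchNorm layer with SyncBatchNorm,
# so batch statistics are computed across all GPUs instead of per GPU; a standalone example:
"""
import torch.nn as nn

net = nn.Sequential(nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16), nn.ReLU())
net = nn.SyncBatchNorm.convert_sync_batchnorm(net)
print(net[1])   # the BatchNorm2d layer is now a SyncBatchNorm module
"""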
model.cuda()
# build the optimizer according to the training config
optimizer = build_optimizer(model, cfg.MODEL.TRAIN.OPTIMIZATION)
# load a checkpoint if one is available
start_epoch = it = 0
last_epoch = -1
if args.pretrained_model is not None:
model.load_params_from_file(filename=args.pretrained_model, to_cpu=dist, logger=logger)
if args.ckpt is not None: # resume training from the given checkpoint
it, start_epoch = model.load_params_with_optimizer(args.ckpt, to_cpu=dist, optimizer=optimizer, logger=logger)
last_epoch = start_epoch + 1
else:
# glob.glob() returns a list of all file paths matching the pattern; the pattern may be an absolute or a relative path
ckpt_list = glob.glob(str(ckpt_dir / '*checkpoint_epoch_*.pth'))
if len(ckpt_list) > 0:
ckpt_list.sort(key=os.path.getmtime)
it, start_epoch = model.load_params_with_optimizer(
ckpt_list[-1], to_cpu=dist, optimizer=optimizer, logger=logger
)
last_epoch = start_epoch + 1
model.train() # switch to train mode before wrapping with DistributedDataParallel so that frozen parameters are handled correctly
if dist_train:
model = nn.parallel.DistributedDataParallel(model, device_ids=[cfg.LOCAL_RANK % torch.cuda.device_count()])
logger.info(model) # log the model structure
# lr_scheduler: the main learning-rate schedule; lr_warmup_scheduler: the warmup schedule used for the first iterations
lr_scheduler, lr_warmup_scheduler = build_scheduler(
optimizer, total_iters_each_epoch=len(train_loader), total_epochs=args.epochs,
last_epoch=last_epoch, optim_cfg=cfg.MODEL.TRAIN.OPTIMIZATION
)
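# build_scheduler is repo code; below is a rough sketch of the idea it wires up (a short linear
# warmup followed by the main step decay), using plain PyTorch schedulers. The epoch and decay
# values are illustrative; the real ones come from cfg.MODEL.TRAIN.OPTIMIZATION.
"""
import torch

def build_scheduler_sketch(optimizer, iters_per_epoch, warmup_epochs=1,
                           decay_epochs=(35, 45), decay_rate=0.1):
    warmup_iters = warmup_epochs * iters_per_epoch
    lr_warmup_scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, lambda it: min(1.0, (it + 1) / warmup_iters))    # ramp LR up to its base value
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=[e * iters_per_epoch for e in decay_epochs], gamma=decay_rate)
    return lr_scheduler, lr_warmup_scheduler
"""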
# -----------------------start training---------------------------
logger.info('**********************Start training %s(%s)**********************' % (cfg.TAG, args.extra_tag))
train_model(
model,
optimizer,
train_loader,
model_func=model_fn_decorator(),
lr_scheduler=lr_scheduler,
optim_cfg=cfg.MODEL.TRAIN.OPTIMIZATION,
start_epoch=start_epoch,
total_epochs=args.epochs,
start_iter=it,
rank=cfg.LOCAL_RANK,
tb_log=tb_log,
ckpt_save_dir=ckpt_dir,
train_sampler=train_sampler,
lr_warmup_scheduler=lr_warmup_scheduler,
ckpt_save_interval=args.ckpt_save_interval,
max_ckpt_save_num=args.max_ckpt_save_num
)
logger.info('**********************End training**********************')
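train_model itself is defined in the repo's train_utils module. The sketch below shows the shape of one training epoch under the assumption that model_func returns (loss, tb_dict, disp_dict), as model_fn_decorator does in OpenPCDet; it is illustrative, not the repo's actual loop.

import torch

def train_one_epoch_sketch(model, optimizer, train_loader, model_func,
                           lr_scheduler, accumulated_iter, tb_log=None, rank=0):
    model.train()
    for batch in train_loader:
        optimizer.zero_grad()
        loss, tb_dict, disp_dict = model_func(model, batch)      # forward pass + loss
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 10)   # clip value is illustrative
        optimizer.step()
        lr_scheduler.step()                                       # per-iteration LR update
        accumulated_iter += 1
        if rank == 0 and tb_log is not None:
            tb_log.add_scalar('train_loss', loss.item(), accumulated_iter)
    return accumulated_iter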