前言
研究方向是目标检测,最近找到该开源库,测试过其性能还是不错的,其文章链接为点击下载,github链接为点击查看,注意新版的代码链接,已经支持到了Pytorch1.5,但是截止现在,最新版的代码编译会有个错误,这个还是个开放问题,还没有解决,所以我这里用的是我前两个月下载的一个旧版本,支持Pytorch1.1+的代码库,我昨天使用Pytorch1.5编译失败也是因为以下原因,所以,我使用了Pytorch1.3+CUDA10.1配置的环境。
源码解读
这里从train.py开始介绍,一点点的阅读源码,有助于后续的代码改进
train.py里面主要包含了两个函数parse_args()和main()
parse_args()函数,在这个配置函数里面,大部分的配置信息我们都是可以理解的,其实比较困难的是关于分布式训练的一些内容,推荐大家阅读下这里
# 这个方法主要是配置一些实验配置参数
def parse_args():
parser = argparse.ArgumentParser(description='Train a detector')
parser.add_argument('config', help='train config file path')#加载配置文件
parser.add_argument('--work_dir', help='the dir to save logs and models')#运行结果储存的位置
parser.add_argument(
'--resume_from', help='the checkpoint file to resume from')#断点续训的复训文件夹
parser.add_argument(
'--validate',
action='store_true',
help='whether to evaluate the checkpoint during training')#是否开启验证
parser.add_argument(
'--gpus',
type=int,
default=1,
help='number of gpus to use '
'(only applicable to non-distributed training)') #声明GPU的数目
parser.add_argument('--seed', type=int, default=None, help='random seed')#设置随机种子
parser.add_argument(
'--deterministic',
action='store_true',
help='whether to set deterministic options for CUDNN backend.')# 是否为CUDNN后端设置确定性选项
# 采用哪种分布式训练模式 torch.distributed.init_process_group()
# 分布式多进程初始化时使用
parser.add_argument(
'--launcher',
choices=['none', 'pytorch', 'slurm', 'mpi'],
default='none',
help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
parser.add_argument(
'--autoscale-lr',
action='store_true',
help='automatically scale lr with the number of gpus')# 根据GPU数目自动更改学习率
args = parser.parse_args()
#获取系统环境变量 声明分布式训练的本地序号
if 'LOCAL_RANK' not in os.environ:
os.environ['LOCAL_RANK'] = str(args.local_rank)
return args
main()函数,关于参数torch.backends.cudnn.benchmark,辅助阅读进入
def main():
args = parse_args()#获得命令行参数,实际上就是获取config配置文件
#读取配置文件
cfg = Config.fromfile(args.config)
# set cudnn_benchmark
#在图片输入尺度固定时开启,可以加速,一般都是关的,只有在固定尺度的网络如SSD512中才开启
#为整个网络的每个卷积层搜索最适合它的卷积实现算法,进而实现网络的加速
if cfg.get('cudnn_benchmark', False):
torch.backends.cudnn.benchmark = True
# update configs according to CLI args
#更新一些配置参数
# 创建工作目录存放训练文件,如果不键入,会自动从py配置文件中生成对应的目录,key为work_dir
if args.work_dir is not None:
cfg.work_dir = args.work_dir
# 断点继续训练的权值文件,为None就没有这一步的设置
if args.resume_from is not None:
cfg.resume_from = args.resume_from
cfg.gpus = args.gpus# gpu数目
# 线性学习率
if args.autoscale_lr:
# apply the linear scaling rule (https://arxiv.org/abs/1706.02677)
cfg.optimizer['lr'] = cfg.optimizer['lr'] * cfg.gpus / 8
# init distributed env first, since logger depends on the dist info.
# 单机训练 默认是none 不使用分布式训练
if args.launcher == 'none':
distributed = False
else:
distributed = True
init_dist(args.launcher, **cfg.dist_params)
# create work_dir 创建文件夹 使用的os.mkdirs()可同时创建多级目录
mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
# init the logger before other steps
# 初始化一些时间戳,得到一些根日志
timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
log_file = osp.join(cfg.work_dir, '{}.log'.format(timestamp))
#log_level在配置文件里有这个key,value=“INFO”训练一次batch就可以看到输出这个str
logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)
# init the meta dict to record some important information such as
# environment info and seed, which will be logged
meta = dict()
# log env info
#得到一些实验环境配置
env_info_dict = collect_env()
env_info = '\n'.join([('{}: {}'.format(k, v))
for k, v in env_info_dict.items()])
dash_line = '-' * 60 + '\n'
logger.info('Environment info:\n' + dash_line + env_info + '\n' +
dash_line)
meta['env_info'] = env_info
# log some basic info
logger.info('Distributed training: {}'.format(distributed))
logger.info('Config:\n{}'.format(cfg.text))
# set random seeds
#设置随机种子,便于实验复现
# 默认为None
if args.seed is not None:
logger.info('Set random seed to {}, deterministic: {}'.format(
args.seed, args.deterministic))
set_random_seed(args.seed, deterministic=args.deterministic)
cfg.seed = args.seed
meta['seed'] = args.seed
# 加载模型
model = build_detector(
cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
#加载数据集
datasets = [build_dataset(cfg.data.train)]
#如果该列表长度为2,则追加验证数据集
if len(cfg.workflow) == 2:
val_dataset = copy.deepcopy(cfg.data.val)
val_dataset.pipeline = cfg.data.train.pipeline
datasets.append(build_dataset(val_dataset))
#判断模型配置是否为空
if cfg.checkpoint_config is not None:
# save mmdet version, config file content and class names in
# checkpoints as meta data
cfg.checkpoint_config.meta = dict(
mmdet_version=__version__,
config=cfg.text,
CLASSES=datasets[0].CLASSES)
# add an attribute for visualization convenience
model.CLASSES = datasets[0].CLASSES
#构建训练器
train_detector(
model,
datasets,
cfg,
distributed=distributed,
validate=args.validate,
timestamp=timestamp,
meta=meta)