GluonCV ----- train_ssd.py

1. 解析命令行

def parse_args():
    parser = argparse.ArgumentParser(description='Train SSD networks.')
    ##--SSD的主干网络
    parser.add_argument('--network', type=str, default='vgg16_atrous',
                        help="Base network name which serves as feature extraction base.")
    ##--SSD网络输入图片的尺寸
    parser.add_argument('--data-shape', type=int, default=300,
                        help="Input data shape, use 300, 512.")
    ##--批量数据大小
    parser.add_argument('--batch-size', type=int, default=32,
                        help='Training mini-batch size')
    ##--训练数据集类型
    parser.add_argument('--dataset', type=str, default='voc',
                        help='Training dataset. Now support voc.')
    ##--数据集存放的路径
    parser.add_argument('--dataset-root', type=str, default='~/.mxnet/datasets/',
                        help='Path of the directory where the dataset is located.')
    ##--多进程加速加载数据
    parser.add_argument('--num-workers', '-j', dest='num_workers', type=int,
                        default=4, help='Number of data workers, you can use larger '
                        'number to accelerate data loading, if you CPU and GPUs are powerful.')
    ##--使用多gpu
    parser.add_argument('--gpus', type=str, default='0',
                        help='Training with GPUs, you can specify 1,3 for example.')
    ##--训练迭代数
    parser.add_argument('--epochs', type=int, default=240,
                        help='Training epochs.')
    ##--继续上次中断之后开始训练
    parser.add_argument('--resume', type=str, default='',
                        help='Resume from previously saved parameters if not None. '
                        'For example, you can resume from ./ssd_xxx_0123.params')
    ##--配合上面的一起用,指定从第几次迭代开始。
    parser.add_argument('--start-epoch', type=int, default=0,
                        help='Starting epoch for resuming, default is 0 for new training.'
                        'You can specify it to 100 for example to start from 100 epoch.')
    ##--学习率
    parser.add_argument('--lr', type=float, default=0.001,
                        help='Learning rate, default is 0.001')
    ##--学习率的衰减率
    parser.add_argument('--lr-decay', type=float, default=0.1,
                        help='decay rate of learning rate. default is 0.1.')
    ##--第几次迭代开始衰减
    parser.add_argument('--lr-decay-epoch', type=str, default='160,200',
                        help='epochs at which learning rate decays. default is 160,200.')
    ##--SGD动量法参数
    parser.add_argument('--momentum', type=float, default=0.9,
                        help='SGD momentum, default is 0.9')
    ##--权重衰减系数(正则化)
    parser.add_argument('--wd', type=float, default=0.0005,
                        help='Weight decay, default is 5e-4')
    ##--指定多少个batch打印一次
    parser.add_argument('--log-interval', type=int, default=100,
                        help='Logging mini-batch interval. Default is 100.')
    ##--训练过程中或者训练完存储参数的时候指定的前缀名
    parser.add_argument('--save-prefix', type=str, default='',
                        help='Saving parameter prefix')
    ##--指定多少个epoch存一次参数,防止训练中断。
    parser.add_argument('--save-interval', type=int, default=10,
                        help='Saving parameters epoch interval, best model will always be saved.')
    ##--指定多少个epoch进行一次验证
    parser.add_argument('--val-interval', type=int, default=1,
                        help='Epoch interval for validation, increase the number will reduce the '
                             'training time if validation is slow.')
    ##--这是什么鬼
    parser.add_argument('--seed', type=int, default=233,
                        help='Random seed to be fixed.')
    ##--这又是什么鬼
    parser.add_argument('--syncbn', action='store_true',
                        help='Use synchronize BN across devices.')
    ##--加速训练的DALI模块
    parser.add_argument('--dali', action='store_true',
                        help='Use DALI for data loading and data preprocessing in training. '
                        'Currently supports only COCO.')
    ##--混合进度训练可用于节省显存和加快速度
    parser.add_argument('--amp', action='store_true',
                        help='Use MXNet AMP for mixed precision training.')
    ##--使用分布式训练框架
    parser.add_argument('--horovod', action='store_true',
                        help='Use MXNet Horovod for distributed training. Must be run with OpenMPI. '
                        '--gpus is ignored when using --horovod.')

    args = parser.parse_args()
    if args.horovod:
        assert hvd, "You are trying to use horovod support but it's not installed"
    return args

2. 获取数据集

##--将源文件的数据解析到类中
def get_dataset(dataset, args):
    if dataset.lower() == 'voc':
    	##--解析VOC类型的数据集。
    	##--1.首先去root路径下,找到VOC2007的文件夹,然后在其子目录中找到trainval.txt的文件。
    	##--2.trainval.txt中存着所有的图片名称,依据这个去解析所有图片的xml文件获取label和bbox等参数。
    	##--3.最后将参数都存在类内。
    	##--4.VOC2012和test2007类似这样。
        train_dataset = gdata.VOCDetection(
            splits=[(2007, 'trainval'), (2012, 'trainval')])
        val_dataset = gdata.VOCDetection(
            splits=[(2007, 'test')])
        ##--评估类初始化一下
        val_metric = VOC07MApMetric(iou_thresh=0.5, class_names=val_dataset.classes)
    elif dataset.lower() == 'coco':
        train_dataset = gdata.COCODetection(root=args.dataset_root + "/coco", splits='instances_train2017')
        val_dataset = gdata.COCODetection(root=args.dataset_root + "/coco", splits='instances_val2017', skip_empty=False)
        val_metric = COCODetectionMetric(
            val_dataset, args.save_prefix + '_eval', cleanup=True,
            data_shape=(args.data_shape, args.data_shape))
        # coco validation is slow, consider increase the validation interval
        if args.val_interval == 1:
            args.val_interval = 10
    else:
        raise NotImplementedError('Dataset: {} not implemented.'.format(dataset))
    ##--返回VOC类型的训练数据类和评估类
    return train_dataset, val_dataset, val_metric

3. 获取可加载入模型的迭代型数据

def get_dataloader(net, train_dataset, val_dataset, data_shape, batch_size, num_workers, ctx):
    """Get dataloader."""
    width, height = data_shape, data_shape
    # use fake data to generate fixed anchors for target generation
    ##--BN, dropout这些层在训练和测试的时候是不同的,
    ##--BN在训练的时候是根据每个mini-batch的均值和方差进行计算并更新参数,在测试的时候是使用训练集上得到的一个参数进行计算。
    ##--dropout在测试的时候是没有的。因此需要在不同的时候加以区别
    ##--使用with autograd.record():,默认为train_mode = True
    with autograd.train_mode():
        _, _, anchors = net(mx.nd.zeros((1, 3, height, width), ctx))
    anchors = anchors.as_in_context(mx.cpu())
    ##--这是一个函数,作用就是将三个数据整理成三个ndarray
    batchify_fn = Tuple(Stack(), Stack(), Stack())  # stack image, cls_targets, box_targets
    ##--1.SSDDefaultTrainTransform内部有__call__函数,进行大量的图像增强动作。
    ##--2.用transform函数会在数据集内填充上述转换function,并且内部巨有__item__函数,因此每次调用都进行转换一次。
    ##--3.batchify表示输入数据的组合方式
    ##--4.last_batch表示除batch之后的余数咋处理
    train_loader = gluon.data.DataLoader(
        train_dataset.transform(SSDDefaultTrainTransform(width, height, anchors)),
        batch_size, True, batchify_fn=batchify_fn, last_batch='rollover', num_workers=num_workers)
    ##--对于 Validation/Test,需要在 DataLoader 的 batchify_fn 里面有一个 Pad 操作(pad_val = -1)。
    ##--这么做是为了计算 mAP 时候方便,DataLoader 返回的是一个 MXNet.NDArray,而不是一个大小不整齐的list。
    val_batchify_fn = Tuple(Stack(), Pad(pad_val=-1))
    val_loader = gluon.data.DataLoader(
        val_dataset.transform(SSDDefaultValTransform(width, height)),
        batch_size, False, batchify_fn=val_batchify_fn, last_batch='keep', num_workers=num_workers)
    return train_loader, val_loader

4. 基于DALI的数据加载

def get_dali_dataset(dataset_name, devices, args):
def get_dali_dataloader(net, train_dataset, val_dataset, data_shape, global_batch_size, num_workers, devices, ctx, horovod, seed):

5. 训练过程中保存参数


def save_params(net, best_map, current_map, epoch, save_interval, prefix):
    current_map = float(current_map)
    ## 1.当MAP创新高的时候存下存下参数
    ## 2.当指定间隔周期到达的时候存下参数 
    if current_map > best_map[0]:
        best_map[0] = current_map
        net.save_params('{:s}_best.params'.format(prefix, epoch, current_map))
        with open(prefix+'_best_map.log', 'a') as f:
            f.write('{:04d}:\t{:.4f}\n'.format(epoch, current_map))
    if save_interval and epoch % save_interval == 0:
        net.save_params('{:s}_{:04d}_{:.4f}.params'.format(prefix, epoch, current_map))

6.验证训练结果

def validate(net, val_data, ctx, eval_metric):
    """Test on validation dataset."""
    ##--初始化一下评估结果
    eval_metric.reset()
    # set nms threshold and topk constraint
    ##--设置非极大值抑制参数
    net.set_nms(nms_thresh=0.45, nms_topk=400)
    net.hybridize(static_alloc=True, static_shape=True)
    ##--在测试集上进行验证
    for batch in val_data:
    	##--将数据分到ctx列表上,可以进行多gpu训练。
    	##--返回类型:list of NDArrays or ndarrays
        data = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0, even_split=False)
        label = gluon.utils.split_and_load(batch[1], ctx_list=ctx, batch_axis=0, even_split=False)
        det_bboxes = []
        det_ids = []
        det_scores = []
        gt_bboxes = []
        gt_ids = []
        gt_difficults = []
        for x, y in zip(data, label):
            # get prediction results
            ids, scores, bboxes = net(x)
            det_ids.append(ids)
            det_scores.append(scores)
            # clip to image size
            det_bboxes.append(bboxes.clip(0, batch[0].shape[2]))
            # split ground truths
            gt_ids.append(y.slice_axis(axis=-1, begin=4, end=5))
            gt_bboxes.append(y.slice_axis(axis=-1, begin=0, end=4))
            gt_difficults.append(y.slice_axis(axis=-1, begin=5, end=6) if y.shape[-1] > 5 else None)

        # update metric
        eval_metric.update(det_bboxes, det_ids, det_scores, gt_bboxes, gt_ids, gt_difficults)
    return eval_metric.get()

6.训练函数


def train(net, train_data, val_data, eval_metric, ctx, args):
    """Training pipeline"""
    net.collect_params().reset_ctx(ctx)
	##--分布式训练用
    if args.horovod:
        hvd.broadcast_parameters(net.collect_params(), root_rank=0)
        trainer = hvd.DistributedTrainer(
                        net.collect_params(), 'sgd',
                        {'learning_rate': args.lr, 'wd': args.wd, 'momentum': args.momentum})
    else:
        trainer = gluon.Trainer(
                    net.collect_params(), 'sgd',
                    {'learning_rate': args.lr, 'wd': args.wd, 'momentum': args.momentum},
                    update_on_kvstore=(False if args.amp else None))
	##--混合精度训练用
    if args.amp:
        amp.init_trainer(trainer)
	##--学习率衰减策略
    # lr decay policy
    lr_decay = float(args.lr_decay)
    lr_steps = sorted([float(ls) for ls in args.lr_decay_epoch.split(',') if ls.strip()])
	##--训练损失函数包括下面两个损失
    mbox_loss = gcv.loss.SSDMultiBoxLoss()
    ##--类别预测损失--打印用
    ce_metric = mx.metric.Loss('CrossEntropy')
    ##--锚框偏移量损失--打印用
    smoothl1_metric = mx.metric.Loss('SmoothL1')

    # set up logger
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    log_file_path = args.save_prefix + '_train.log'
    log_dir = os.path.dirname(log_file_path)
    if log_dir and not os.path.exists(log_dir):
        os.makedirs(log_dir)
    fh = logging.FileHandler(log_file_path)
    logger.addHandler(fh)
    logger.info(args)
    logger.info('Start training from [Epoch {}]'.format(args.start_epoch))
    best_map = [0]
	##--开始迭代训练
    for epoch in range(args.start_epoch, args.epochs):
    	##--学习率衰减策略
        while lr_steps and epoch >= lr_steps[0]:
            new_lr = trainer.learning_rate * lr_decay
            lr_steps.pop(0)
            trainer.set_learning_rate(new_lr)
            logger.info("[Epoch {}] Set learning rate to {}".format(epoch, new_lr))
        ce_metric.reset()
        smoothl1_metric.reset()
        tic = time.time()
        btic = time.time()
        net.hybridize(static_alloc=True, static_shape=True)
		##-- 
        for i, batch in enumerate(train_data):
            if args.dali:
                # dali iterator returns a mxnet.io.DataBatch
                data = [d.data[0] for d in batch]
                box_targets = [d.label[0] for d in batch]
                cls_targets = [nd.cast(d.label[1], dtype='float32') for d in batch]

            else:
            	##--之前DataLoder的batchify设置了3个stack刚号对应batch三个值。
                data = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0)
                cls_targets = gluon.utils.split_and_load(batch[1], ctx_list=ctx, batch_axis=0)
                box_targets = gluon.utils.split_and_load(batch[2], ctx_list=ctx, batch_axis=0)
			
            with autograd.record():
                cls_preds = []
                box_preds = []
                ##--计算训练集预测值
                for x in data:
                    cls_pred, box_pred, _ = net(x)
                    cls_preds.append(cls_pred)
                    box_preds.append(box_pred)
                ##--计算损失函数
                sum_loss, cls_loss, box_loss = mbox_loss(
                    cls_preds, box_preds, cls_targets, box_targets)
                if args.amp:
                    with amp.scale_loss(sum_loss, trainer) as scaled_loss:
                        autograd.backward(scaled_loss)
                else:
                    autograd.backward(sum_loss)
            # since we have already normalized the loss, we don't want to normalize
            # by batch-size anymore
            trainer.step(1)
			
            if (not args.horovod or hvd.rank() == 0):
                local_batch_size = int(args.batch_size // (hvd.size() if args.horovod else 1))
                ##--更新一下两个损失的内容
                ce_metric.update(0, [l * local_batch_size for l in cls_loss])
                smoothl1_metric.update(0, [l * local_batch_size for l in box_loss])
                ##--到打印周期之后打印一下两个损失
                if args.log_interval and not (i + 1) % args.log_interval:
                    name1, loss1 = ce_metric.get()
                    name2, loss2 = smoothl1_metric.get()
                    logger.info('[Epoch {}][Batch {}], Speed: {:.3f} samples/sec, {}={:.3f}, {}={:.3f}'.format(
                        epoch, i, args.batch_size/(time.time()-btic), name1, loss1, name2, loss2))
                btic = time.time()
		
        if (not args.horovod or hvd.rank() == 0):
            name1, loss1 = ce_metric.get()
            name2, loss2 = smoothl1_metric.get()
            logger.info('[Epoch {}] Training cost: {:.3f}, {}={:.3f}, {}={:.3f}'.format(
                epoch, (time.time()-tic), name1, loss1, name2, loss2))
            ##--到指代的迭代数之后,使用验证集验证一下,计算mAP。
            if (epoch % args.val_interval == 0) or (args.save_interval and epoch % args.save_interval == 0):
                # consider reduce the frequency of validation to save time
                map_name, mean_ap = validate(net, val_data, ctx, eval_metric)
                val_msg = '\n'.join(['{}={}'.format(k, v) for k, v in zip(map_name, mean_ap)])
                logger.info('[Epoch {}] Validation: \n{}'.format(epoch, val_msg))
                current_map = float(mean_ap[-1])
            else:
                current_map = 0.
            ##--计算完map之后存一下参数
            save_params(net, best_map, current_map, epoch, args.save_interval, args.save_prefix)

7.主函数

if __name__ == '__main__':
    args = parse_args()

    if args.amp:
        amp.init()

    if args.horovod:
        hvd.init()

    # fix seed for mxnet, numpy and python builtin random generator.
    gutils.random.seed(args.seed)

    # training contexts
    if args.horovod:
        ctx = [mx.gpu(hvd.local_rank())]
    else:
        ctx = [mx.gpu(int(i)) for i in args.gpus.split(',') if i.strip()]
        ctx = ctx if ctx else [mx.cpu()]

    # network
    net_name = '_'.join(('ssd', str(args.data_shape), args.network, args.dataset))
    args.save_prefix += net_name
    if args.syncbn and len(ctx) > 1:
        net = get_model(net_name, pretrained_base=True, norm_layer=gluon.contrib.nn.SyncBatchNorm,
                        norm_kwargs={'num_devices': len(ctx)})
        async_net = get_model(net_name, pretrained_base=False)  # used by cpu worker
    else:
        net = get_model(net_name, pretrained_base=True, norm_layer=gluon.nn.BatchNorm)
        async_net = net
    if args.resume.strip():
        net.load_parameters(args.resume.strip())
        async_net.load_parameters(args.resume.strip())
    else:
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")
            net.initialize()
            async_net.initialize()
            # needed for net to be first gpu when using AMP
            net.collect_params().reset_ctx(ctx[0])

    # training data
    if args.dali:
        if not dali_found:
            raise SystemExit("DALI not found, please check if you installed it correctly.")
        devices = [int(i) for i in args.gpus.split(',') if i.strip()]
        train_dataset, val_dataset, eval_metric = get_dali_dataset(args.dataset, devices, args)
        train_data, val_data = get_dali_dataloader(
            async_net, train_dataset, val_dataset, args.data_shape, args.batch_size, args.num_workers,
            devices, ctx[0], args.horovod, args.seed)
    else:
        train_dataset, val_dataset, eval_metric = get_dataset(args.dataset, args)
        batch_size = (args.batch_size // hvd.size()) if args.horovod else args.batch_size
        train_data, val_data = get_dataloader(
            async_net, train_dataset, val_dataset, args.data_shape, batch_size, args.num_workers, ctx[0])



    # training
    train(net, train_data, val_data, eval_metric, ctx, args)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值