mxnet复现SSD之训练脚本的实现

mxnet复现SSD系列文章目录

一、数据集的导入.
二、SSD模型架构.
三、训练脚本的实现.
四、损失、评价函数.
五、预测结果.



前言

本项目是按照pascal voc的格式读取数据集,数据集为kaggle官网提供的口罩检测数据集,地址:Face Mask Detection,模型架构参考自gluoncv ssd_300_vgg16_atrous_voc源码


一、代码实现

import mxnet as mx
from mxnet import autograd, contrib, gluon, nd
from train.smooth_l1 import smooth_l1, FocalLoss
from utils import utils
from tools.draw_details import draw_details
from model.vgg_ssd import get_model

import os
import time
import logging


# 主训练函数
def train(data_path, num_classes, data_size, batch_size, epochs, wd, momentum, lr, save_model_path, log_file_path, ctx=mx.cpu()):

    train_rec_path = os.path.join(data_path, 'train.rec')
    train_idx_path = os.path.join(data_path, 'train.idx')
    val_rec_path = os.path.join(data_path, 'val.rec')

	# 数据预处理部分
    augs = mx.image.CreateDetAugmenter(data_shape=(3, data_size, data_size),
                                       rand_crop=1,
                                       min_object_covered=0.9,
                                       aspect_ratio_range=(0.5, 2),
                                       area_range=(0.1, 1.5),
                                       max_attempts=100,
                                       rand_mirror=True,
                                       rand_gray=0.2,
                                       brightness=0.5,
                                       contrast=0.5,
                                       saturation=0.5,
                                       rand_pad=0.4,
                                       hue=0.5,
                                       mean=True,
                                       std=True,
                                       )
	# 训练集
    train_iter = mx.image.ImageDetIter(
        path_imgidx=train_idx_path,
        path_imgrec=train_rec_path,
        batch_size=batch_size,
        data_shape=(3, data_size, data_size),
        shuffle=True,
        aug_list=augs,
    )
	# 验证集
    val_iter = mx.image.ImageDetIter(
        path_imgrec=val_rec_path,
        batch_size=batch_size,
        data_shape=(3, data_size, data_size),
        shuffle=False,
        mean=True,
        std=True
    )
	# 加载模型
    net = get_model(num_classes, pretrained_base=True, ctx=ctx)
    net.collect_params().reset_ctx(ctx=ctx)
    net.hybridize()

    # lrs = mx.lr_scheduler.FactorScheduler(step=200, factor=0.8, stop_factor_lr=lr, base_lr=lr)
	# 优化器
    trainer = gluon.Trainer(net.collect_params(), 'sgd',
                            {'learning_rate': lr, 'wd': wd, 'momentum': momentum})
	# 损失函数
    cls_loss = gluon.loss.SoftmaxCrossEntropyLoss()
    # cls_loss = FocalLoss()
    bbox_loss = smooth_l1()
	# 评价函数
    def evaluate_accuracy(data_iter, net, ctx):
        """
        :param data_iter: 数据集加载器
        :param net: 模型网络
        :param ctx: 可使用的gpu列表
        :return: 验证集准确率
        """

        data_iter.reset()
        outs, labels = None, None
        for batch in data_iter:
            X = batch.data[0].as_in_context(ctx)
            Y = batch.label[0].as_in_context(ctx)

            anchors,bbox_preds,cls_preds = net(X)
            # 为每个锚框标注类别和偏移量
            cls_probs = nd.SoftmaxActivation(cls_preds.transpose((0, 2, 1)), mode='channel')
            out = nd.contrib.MultiBoxDetection(cls_probs, bbox_preds, anchors,
                                               force_suppress=True, clip=False, nms_threshold=0.45)
            if outs is None:
                outs = out
                labels = Y
            else:
                outs = nd.concat(outs, out, dim=0)
                labels = nd.concat(labels, Y, dim=0)

            AP = utils.evaluate_MAP(outs, labels)

            return AP
	
	# 打印训练日志
    # set up logger
    logging.basicConfig(format='%(asctime)s %(message)s')
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    fh = logging.FileHandler(log_file_path, mode='w')
    # 定义handler的输出格式
    formatter = logging.Formatter("%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s")
    fh.setFormatter(formatter)

    logger.addHandler(fh)

    ce_metric = mx.metric.Loss('CrossEntropy')
    smoothl1_metric = mx.metric.Loss('SmoothL1')

	# 训练
    for epoch in range(epochs):

        ce_metric.reset()
        smoothl1_metric.reset()

        train_iter.reset()  # 从头读取数据
        btic = time.time()

        for i, batch in enumerate(train_iter):
            X = batch.data[0].as_in_context(ctx)
            Y = batch.label[0].as_in_context(ctx)
            with autograd.record():
                # 生成多尺度的锚框,为每个锚框预测类别和偏移量
                anchors, bbox_preds, cls_preds = net(X)
                # 为每个锚框标注类别和偏移量
                bbox_labels, bbox_masks, cls_labels = contrib.nd.MultiBoxTarget(
                    anchors, Y, cls_preds.transpose((0, 2, 1)),
                    negative_mining_ratio=3, negative_mining_thresh=.5)
                # 根据类别和偏移量的预测和标注值计算损失函数
                cls = cls_loss(cls_preds, cls_labels)
                bbox = bbox_loss(bbox_preds * bbox_masks, bbox_labels * bbox_masks)
                l = cls + bbox
            l.backward()
            trainer.step(batch_size)

            if i % 50 == 0:
                ce_metric.update(0, cls)
                smoothl1_metric.update(0, bbox)
                name1, loss1 = ce_metric.get()
                name2, loss2 = smoothl1_metric.get()
                val_AP = evaluate_accuracy(val_iter, net, ctx)
                logger.info('[Epoch {}][Batch {}], Speed: {:.3f} samples/sec, {}={:.2e}, {}={:.2e}, val_AP={:.3f}'.format(
                    epoch, i, batch_size / (time.time() - btic), name1, loss1, name2, loss2, val_AP))
            btic = time.time()

	# 保存训练的模型参数
    net.save_parameters(save_model_path)
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值