mxnet复现SSD系列文章目录
一、数据集的导入.
二、SSD模型架构.
三、训练脚本的实现.
四、损失、评价函数.
五、预测结果.
前言
本项目按照 Pascal VOC 的格式读取数据集。数据集为 Kaggle 官网提供的口罩检测数据集(地址:Face Mask Detection),模型架构参考自 GluonCV 的 ssd_300_vgg16_atrous_voc 源码。
一、损失函数
类别损失函数采用 SoftmaxCrossEntropyLoss 或 focal loss
boundingbox 损失函数采用 smooth_l1
代码实现
from mxnet.gluon.loss import Loss
class smooth_l1(Loss):
    """Smooth-L1 regression loss between a prediction and its target.

    The element-wise loss is ``F.smooth_l1(pred - label, scalar=1.)`` and is
    averaged over every axis except the batch axis, so one value is returned
    per sample.

    Parameters
    ----------
    weight : float or None
        Global scalar multiplier applied to the loss. ``None`` (the default)
        leaves the loss unscaled.
    batch_axis : int
        Axis of the inputs that indexes the samples of the batch.
    """

    def __init__(self, weight=None, batch_axis=0, **kwargs):
        super(smooth_l1, self).__init__(weight, batch_axis, **kwargs)

    def hybrid_forward(self, F, pred, label, sample_weight=None):
        """Return the per-sample smooth-L1 loss.

        Parameters
        ----------
        F : module
            Operator namespace handed in by Gluon (ndarray or symbol API).
        pred, label :
            Prediction and target; they are subtracted element-wise.
        sample_weight : optional
            Per-element weights, broadcast-multiplied onto the loss. Omitting
            it keeps the behaviour of a plain ``loss(pred, label)`` call.
        """
        loss = F.smooth_l1(pred - label, scalar=1.)
        if sample_weight is not None:
            loss = F.broadcast_mul(loss, sample_weight)
        # The constructor accepts ``weight``; apply it here so that it is not
        # silently ignored. The base class keeps it in ``self._weight``.
        if self._weight is not None:
            loss = loss * self._weight
        return F.mean(loss, axis=self._batch_axis, exclude=True)
class FocalLoss(Loss):
    """Softmax focal loss for classification.

    For the probability ``pt`` that the softmax assigns to the labelled class
    the loss is ``-alpha * (1 - pt) ** gamma * log(pt)``, averaged over every
    axis except the batch axis.

    Parameters
    ----------
    axis : int
        Axis of the input holding the class scores. It is used both for the
        softmax normalisation and for picking the labelled class.
    alpha : float
        Constant factor applied to the loss of every sample.
    gamma : float
        Focusing exponent; larger values down-weight well-classified samples.
    batch_axis : int
        Axis of the inputs that indexes the samples of the batch.
    """

    def __init__(self, axis=-1, alpha=0.25, gamma=2, batch_axis=0, **kwargs):
        super(FocalLoss, self).__init__(None, batch_axis, **kwargs)
        self.alpha = alpha
        self.gamma = gamma
        self.axis = axis
        # Kept as a public attribute for existing callers; the reduction below
        # uses the base class' ``_batch_axis``.
        self.batch_axis = batch_axis

    def hybrid_forward(self, F, y, label):
        """Return the per-sample focal loss.

        Parameters
        ----------
        F : module
            Operator namespace handed in by Gluon (ndarray or symbol API).
        y :
            Raw (pre-softmax) class scores.
        label :
            Class indices used to pick along ``self.axis``.
        """
        # Normalise along the same axis that is used for picking, and stay in
        # log space so that a vanishing probability gives a large finite loss
        # instead of log(0) = -inf.
        log_prob = F.log_softmax(y, axis=self.axis)
        log_pt = F.pick(log_prob, label, axis=self.axis, keepdims=True)
        pt = F.exp(log_pt)
        loss = -self.alpha * ((1 - pt) ** self.gamma) * log_pt
        return F.mean(loss, axis=self._batch_axis, exclude=True)
二、评价函数
评价函数计算每个类别的 recall、precision 和 AP 值
代码实现
class_recs = {}
npos = 0
for imagename in imagenames:
R = [obj for obj in recs[imagename] if obj['name'] == classname]
bbox = np.array([x['bbox'] for x in R])
difficult = np.array([x['difficult'] for x in R]).astype(np.bool)
det = [False] * len(R) #这个值是用来判断是否重复检测的
npos = npos + sum(~difficult)
class_recs[imagename] = {'bbox': bbox,
'difficult': difficult,
'det': det}
# read dets
detfile = detpath.format(classname)
with open(detfile, 'r') as f:
lines = f.readlines()
splitlines = [x.strip().split(' ') for x in lines]
image_ids = [x[0] for x in splitlines]
confidence = np.array([float(x[1]) for x in splitlines])
BB = np.array([[float(z) for z in x[2:]] for x in splitlines])
# sort by confidence
sorted_ind = np.argsort(-confidence)
BB = BB[sorted_ind, :]
image_ids = [image_ids[x] for x in sorted_ind]
# go down dets and mark TPs and FPs
nd = len(image_ids)
tp = np.zeros(nd)
fp = np.zeros(nd)
for d in range(nd):
R = class_recs[image_ids[d]]
bb = BB[d, :].astype(float)
ovmax = -np.inf
BBGT = R['bbox'].astype(float)
if BBGT.size > 0:
# compute overlaps
# intersection
ixmin = np.maximum(BBGT[:, 0], bb[0])
iymin = np.maximum(BBGT[:, 1], bb[1])
ixmax = np.minimum(BBGT[:, 2], bb[2])
iymax = np.minimum(BBGT[:, 3], bb[3])
iw = np.maximum(ixmax - ixmin + 1., 0.)
ih = np.maximum(iymax - iymin + 1., 0.)
inters = iw * ih
# union
uni = ((bb[2] - bb[0] + 1.) * (bb[3] - bb[1] + 1.) +
(BBGT[:, 2] - BBGT[:, 0] + 1.) *
(BBGT[:, 3] - BBGT[:, 1] + 1.) - inters)
overlaps = inters / uni
ovmax = np.max(overlaps)
jmax = np.argmax(overlaps)
if ovmax > ovthresh:
if not R['difficult'][jmax]:
if not R['det'][jmax]:
tp[d] = 1.
R['det'][jmax] = 1 #判断是否重复检测,检测过一次以后,值就从False变为1了
else:
fp[d] = 1.
else:
fp[d] = 1.
# compute precision recall
fp = np.cumsum(fp)
tp = np.cumsum(tp)
rec = tp / float(npos)
# avoid divide by zero in case the first detection matches a difficult
# ground truth
prec = tp / np.maximum(tp + fp, np.finfo(np.float64).eps)
ap = voc_ap(rec, prec, use_07_metric)
return rec, prec, ap
计算AP值
def voc_ap(rec, prec, use_07_metric=False):
    """Compute the VOC average precision from a precision/recall curve.

    Parameters
    ----------
    rec, prec : array-like
        Recall and precision values of equal length, one pair per detection
        in ranking order. Lists and NumPy arrays are both accepted.
    use_07_metric : bool
        If true, use the VOC 2007 11-point interpolation; otherwise integrate
        the area under the monotone precision envelope (default: False).

    Returns
    -------
    float
        The average precision; 0.0 for an empty curve.
    """
    rec = np.asarray(rec, dtype=float)
    prec = np.asarray(prec, dtype=float)

    if use_07_metric:
        # 11-point metric. The thresholds are built as k / 10 rather than with
        # np.arange(0., 1.1, 0.1), whose accumulated steps drift (for example
        # 0.6000000000000001) and make a recall of exactly 0.6 miss its
        # threshold.
        ap = 0.
        for k in range(11):
            reached = rec >= k / 10.
            p = np.max(prec[reached]) if np.any(reached) else 0.
            ap = ap + p / 11.
        return float(ap)

    # Append sentinel values at both ends of the curve.
    mrec = np.concatenate(([0.], rec, [1.]))
    mpre = np.concatenate(([0.], prec, [0.]))

    # Precision envelope: every entry becomes the maximum precision found at
    # that position or further right.
    mpre = np.maximum.accumulate(mpre[::-1])[::-1]

    # Indices at which the recall changes.
    idx = np.where(mrec[1:] != mrec[:-1])[0]

    # Area under the curve: sum of (delta recall) * precision.
    return float(np.sum((mrec[idx + 1] - mrec[idx]) * mpre[idx + 1]))
计算mAP值
def mAP(ovthresh=0.5, use_07_metric=False):
    """Evaluate every class and return the mean average precision.

    The paths are obtained from ``get_dir('kitti')`` and the class names from
    ``get_classlist``. Each class is evaluated with ``voc_eval`` and its AP,
    final recall and final precision are printed.

    Parameters
    ----------
    ovthresh : float
        Overlap threshold forwarded to ``voc_eval``. The default of 0.5 is the
        value the evaluation call previously hard-coded.
    use_07_metric : bool
        Forwarded to ``voc_eval`` to select the VOC 2007 11-point AP.

    Returns
    -------
    float
        Mean of the per-class AP values; 0.0 when the class list is empty.
    """
    detpath, annopath, imagesetfile, cachedir, class_path = get_dir('kitti')
    class_list = get_classlist(class_path)
    if len(class_list) == 0:
        # Nothing to evaluate; avoid dividing by zero below.
        return 0.0

    ap_sum = 0.0
    for classname in class_list:
        rec, prec, ap = voc_eval(detpath,
                                 annopath,
                                 imagesetfile,
                                 classname,
                                 cachedir,
                                 ovthresh=ovthresh,
                                 use_07_metric=use_07_metric,
                                 kitti=True)
        # A class without any detection yields empty curves; report 0 for it
        # rather than failing on rec[-1].
        last_rec = rec[-1] if len(rec) else 0.0
        last_prec = prec[-1] if len(prec) else 0.0
        print('on {}, the ap is {}, recall is {}, precision is {}'.format(classname, ap, last_rec, last_prec))
        ap_sum += ap

    return float(ap_sum) / len(class_list)