在使用Mxnet进行SSD重构的时候,Backbone使用的MobileNetv1,并使用了PreTrainModel的参数进行BackBone参数初始化,剩余网络使用MSRA进行随机初始化。训练初期,Freeze掉Backbone的参数只训练多尺度网络,为防止多尺度网络将Backbone的参数带偏。训练中期,所有的参数全部打开一起训练。训练后期,调整Size、Steps以及网络模型等参数。
遇到问题:训练初期,Finetuning结束后的模型可以进行大概预测,但是训练中期,模型检测不到物体。
尝试解决方法:
(1)是不是学习率太高,网络收敛不了,参数释放后,多尺度网络参数将Backbone参数带偏。:尝试更改了学习率,但是训练到后期,还是检测不到。
(2)是不是网络模型架构没有很好的衔接。:将网络每一层的shape打印结果与SSD_Mobilenet的shape相同,排除。
(3)label转为target失败。:分析代码,进行解析。
解决方法:
(1)在多尺度检测的时候,我们通过增加了一个1024维的卷积层,提高了模型的检测性能,并且使得检测性能提高诸多。
(2)在进行多尺度检测的时候,在设置减半网络的时候,是使用MaxPooling、AveragePooling还是1*1卷积改变通道+3*3卷积(stride=2,padding=1),并且分别进行了调试,最后MaxPooling效果更好。
Mxnet进行重构SSD之数据导入
标注的图片是xml,本次重构使用的是Voc数据格式。
# -*- coding: utf-8 -*-
"""
Created on Sun Feb 24 10:32:08 2019
@author: 20181126
"""
from mxnet.gluon import data
import os
import cv2
try:
import xml.etree.cElementTree as ET
except ImportError:
import xml.etree.ElementTree as ET
import numpy as np
import mxnet as mx
class Cigarette_DataSet(data.Dataset):
def __init__(self,image_path,annotation_path,class_map,data_aug=False):
self.image_path=image_path
self.anno_path=annotation_path
self.class_map=class_map
self._image_ids=self._load_ids(self.image_path)
self.index_map=dict(zip(self.classes,range(self.num_classes)))
def _load_ids(self,path):
ids=os.listdir(path)
return ids
@property
def classes(self):
"""Category names."""
'''####class_map是class列表'''
return self.class_map
@property
def num_classes(self):
"""Category names."""
return len(self.classes)
def _load_label(self,idx):
img_id=self._image_ids[idx]
prefix,suffix=os.path.splitext(img_id)
xml_path=os.path.join(self.anno_path,prefix+'_detection.xml')
root = ET.parse(xml_path).getroot()
size = root.find('size')
width = float(size.find('width').text)
height = float(size.find('height').text)
'''if idx not in self._im_shapes:
# store the shapes for later usage
self._im_shapes[idx] = (width, height)'''
label = []
for obj in root.iter('object'):
difficult = int(obj.find('difficult').text)
cls_name = obj.find('name').text.strip().lower()
#if cls_name not in self.classes:
#print('{} is not in {}'.format(cls_name,self.classes))
#print('123')
#continue
#cls_id = self.index_map[cls_name]
cls_id=0
xml_box = obj.find('bndbox')
xmin = (float(xml_box.find('xmin').text) - 1)
ymin = (float(xml_box.find('ymin').text) - 1)
xmax = (float(xml_box.find('xmax').text) - 1)
ymax = (float(xml_box.find('ymax').text) - 1)
try:
self._validate_label(xmin, ymin, xmax, ymax, width, height)
except AssertionError as e:
raise RuntimeError("Invalid label at {}, {}".format(xml_path, e))
label.append([xmin, ymin, xmax, ymax, cls_id])
return np.array(label)
def _validate_label(self, xmin, ymin, xmax, ymax, width, height):
"""Validate labels."""
assert 0 <= xmin < width, "xmin must in [0, {}), given {}".format(width, xmin)
assert 0 <= ymin < height, "ymin must in [0, {}), given {}".format(height, ymin)
assert xmin < xmax <= width, "xmax must in (xmin, {}], given {}".format(width, xmax)
assert ymin < ymax <= height, "ymax must in (ymin, {}], given {}".format(height, ymax)
def __getitem__(self,idx):
img_path=os.path.join(self.image_path,self._image_ids[idx])
image=cv2.imread(img_path)
image=image.astype(np.float32)
label=self._load_label(idx)
return mx.nd.array(image),label##nd.array版本的image
def __len__(self):
return len(self._image_ids)
每一次return出的图片type是mxnet的nd.array(),label是np.array()