使用Mxnet进行重构SSD(一)

    在使用Mxnet进行SSD重构的时候,Backbone使用的MobileNetv1,并使用了PreTrainModel的参数进行BackBone参数初始化,剩余网络使用MSRA进行随机初始化。训练初期,Freeze掉Backbone的参数只训练多尺度网络,为防止多尺度网络将Backbone的参数带偏。训练中期,所有的参数全部打开一起训练。训练后期,调整Size、Steps以及网络模型等参数。

遇到问题:训练初期,Finetuning结束后的模型可以进行大概预测,但是训练中期,模型检测不到物体。

尝试解决方法:

        (1)是不是学习率太高,网络收敛不了,参数释放后,多尺度网络参数将Backbone参数带偏。:尝试更改了学习率,但是训练到后期,还是检测不到。

        (2)是不是网络模型架构没有很好的衔接。:将网络每一层的shape打印结果与SSD_Mobilenet的shape相同,排除。

        (3)label转为target失败。:分析代码,进行解析。

解决方法:

        (1)在多尺度检测的时候,我们通过增加了一个1024维的卷积层,提高了模型的检测性能,并且使得检测性能提高诸多。

        (2)在进行多尺度检测的时候,在设置减半网络的时候,是使用MaxPooling、AveragePooling还是1*1卷积改变通道+3*3卷积(stride=2,padding=1),并且分别进行了调试,最后MaxPooling效果更好。

Mxnet进行重构SSD之数据导入

        标注的图片是xml,本次重构使用的是Voc数据格式。

# -*- coding: utf-8 -*-
"""
Created on Sun Feb 24 10:32:08 2019

@author: 20181126
"""
from mxnet.gluon import data
import os
import cv2
try:
    import xml.etree.cElementTree as ET
except ImportError:
    import xml.etree.ElementTree as ET
import numpy as np
import mxnet as mx
class Cigarette_DataSet(data.Dataset):
    def __init__(self,image_path,annotation_path,class_map,data_aug=False):
        self.image_path=image_path
        self.anno_path=annotation_path
        self.class_map=class_map
        self._image_ids=self._load_ids(self.image_path)
        self.index_map=dict(zip(self.classes,range(self.num_classes)))
    def _load_ids(self,path):
        ids=os.listdir(path)
        return ids
    @property
    def classes(self):
        """Category names."""
        '''####class_map是class列表'''
        return self.class_map
    @property
    def num_classes(self):
        """Category names."""
        return len(self.classes)
    def _load_label(self,idx):
        img_id=self._image_ids[idx]
        prefix,suffix=os.path.splitext(img_id)
        xml_path=os.path.join(self.anno_path,prefix+'_detection.xml')
        root = ET.parse(xml_path).getroot()
        size = root.find('size')
        width = float(size.find('width').text)
        height = float(size.find('height').text)
        '''if idx not in self._im_shapes:
            # store the shapes for later usage
            self._im_shapes[idx] = (width, height)'''
        label = []
        for obj in root.iter('object'):
            difficult = int(obj.find('difficult').text)
            cls_name = obj.find('name').text.strip().lower()
            #if cls_name not in self.classes:
                #print('{} is not in {}'.format(cls_name,self.classes))
                #print('123')
                #continue
            #cls_id = self.index_map[cls_name]
            cls_id=0
            xml_box = obj.find('bndbox')
            xmin = (float(xml_box.find('xmin').text) - 1)
            ymin = (float(xml_box.find('ymin').text) - 1)
            xmax = (float(xml_box.find('xmax').text) - 1)
            ymax = (float(xml_box.find('ymax').text) - 1)
            try:
                self._validate_label(xmin, ymin, xmax, ymax, width, height)
            except AssertionError as e:
                raise RuntimeError("Invalid label at {}, {}".format(xml_path, e))
            label.append([xmin, ymin, xmax, ymax, cls_id])
        return np.array(label)
    def _validate_label(self, xmin, ymin, xmax, ymax, width, height):
        """Validate labels."""
        assert 0 <= xmin < width, "xmin must in [0, {}), given {}".format(width, xmin)
        assert 0 <= ymin < height, "ymin must in [0, {}), given {}".format(height, ymin)
        assert xmin < xmax <= width, "xmax must in (xmin, {}], given {}".format(width, xmax)
        assert ymin < ymax <= height, "ymax must in (ymin, {}], given {}".format(height, ymax)
    def __getitem__(self,idx):
        img_path=os.path.join(self.image_path,self._image_ids[idx])
        image=cv2.imread(img_path)
        image=image.astype(np.float32)
        label=self._load_label(idx)
        return mx.nd.array(image),label##nd.array版本的image
    def __len__(self):
        return len(self._image_ids)

每一次return出的图片type是mxnet的nd.array(),label是np.array()

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值