pytorch框架下ssd代码的解析

背景

上一篇文章写了pytorch版本yolov3的源码。代码较为简单。这篇文章准备写一篇代码较为复杂的SSD实现版本。该版本的github地址为:

https://github.com/amdegroot/ssd.pytorch

在该github下的使用操作方法比较完善,就不在这里记录了。在这里只记录代码的解析。

数据读入部分

数据读入部分的代码为

dataset=VOCDetection(root=args.dataset_root,transform=SSDAugmentation(cfg['min_dim'],MEANS))

transform

其中SSDAugmentation函数为图像、标签转换函数,具体的定义为之在utils/augmentations.py

class SSDAugmentation(object):
    def __init__(self, size=300, mean=(104, 117, 123)):
        self.mean = mean
        self.size = size
        self.augment = Compose([
            ConvertFromInts(),                      #数据类型转换
            ToAbsoluteCoords(),                     #位置信息转换
            PhotometricDistort(),                   #镜像翻转
            Expand(self.mean),                      #扩展图像
            RandomSampleCrop(),                     #随机裁剪
            RandomMirror(),                         #随机镜像翻转
            ToPercentCoords(),                      #位置归一化
            Resize(self.size),                      #图像尺寸缩放
            SubtractMeans(self.mean)                #图像去均值
        ])

    def __call__(self, img, boxes, labels):
        return self.augment(img, boxes, labels)

其中__call__方法是python的一个方法,该方法的定义表明该类可以直接调用。
在augmentations.py中详细的定义了每一个数据转换的方法,具体的定义就不叙述了。

Dataset

类VOCDetetction继承于torch.utils.data.Dataset类别,需要定义imgs,getitem、__len__方法。
python的__getitem__方法可以让对象实现迭代功能。在这里,会返回单张图像及其标签。定义如下:

class VOCDetection(data.Dataset):
    def __init__(self, root,
                 image_sets=[('2007', 'trainval'), ('2012', 'trainval')],
                 transform=None, target_transform=VOCAnnotationTransform(),
                 dataset_name='VOC0712'):
        self.root = root									#设置数据集的根目录
        self.image_set = image_sets							#设置要选用的数据集
        self.transform = transform							#定义图像转换方法
        self.target_transform = target_transform			#定义标签的转换方法
        self.name = dataset_name							#定义数据集名称
        self._annopath = osp.join('%s', 'Annotations', '%s.xml')	#记录标签的位置
        self._imgpath = osp.join('%s', 'JPEGImages', '%s.jpg')		#记录图像的位置
        self.ids = list()											#记录数据集中的所有图像的名字
        #读入数据集中的图像名称,可以依照该名称和_annopath、_imgpath推断出图片、描述文件存储的位置
        for (year, name) in image_sets:
            rootpath = osp.join(self.root, 'VOC' + year)
            for line in open(osp.join(rootpath, 'ImageSets', 'Main', name + '.txt')):
                self.ids.append((rootpath, line.strip()))

    def __getitem__(self, index):
        im, gt, h, w = self.pull_item(index)
        return im, gt
    def __len__(self):
        return len(self.ids)

    def pull_item(self, index):
        img_id = self.ids[index]								#获取index对应的img名称
        target = ET.parse(self._annopath % img_id).getroot()	#读取xml文件
        img = cv2.imread(self._imgpath % img_id)				#获取图像
        height, width, channels = img.shape						#获取图像的尺寸
        if self.target_transform is not None:
            target = self.target_transform(target, width, height)          #获取target
        if self.transform is not None:
            target = np.array(target)
            img, boxes, labels = self.transform(img, target[:, :4], target[:, 4])   #对图像、target进行转换
            # to rgb
            img = img[:, :, (2, 1, 0)]		#opencv读入图像的顺序是BGR,该操作将图像转为RGB
            # img = img.transpose(2, 0, 1)
            target = np.hstack((boxes, np.expand_dims(labels, axis=1)))
        return torch.from_numpy(img).permute(2, 0, 1), target, height, width        #返回image、label、宽高.这里的permute(2,0,1)是将原有的三维(28,28,3)变为(3,28,28),将通道数提前,为了统一torch的后续训练操作。

    def pull_image(self, index):
        img_id = self.ids[index]
        return cv2.imread(self._imgpath % img_id, cv2.IMREAD_COLOR)
    def pull_anno(self, index):
        img_id = self.ids[index]
        anno = ET.parse(self._annopath % img_id).getroot()
        gt = self.target_transform(anno, 1, 1)
        return img_id[1], gt
    def pull_tensor(self, index):
        return torch.Tensor(self.pull_image(index)).unsqueeze_(0)

DataLoader

实际的使用过程中,使用DataLoader,批量的读入数据。不知道什么原因,在windows下执行的时候,worker只能设置为0,否则跑不起来。DataLoader实现了一个并行读入图像、标签的功能。

    data_loader = data.DataLoader(dataset, args.batch_size,                         #数据loader
                                  num_workers=args.num_workers,
                                  shuffle=True, collate_fn=detection_collate,
                                  pin_memory=True)

data_loader可以用以下两种方式调用

#方法一
for (img,target) in data_loader
	print(img,shape)  #height,width,channels
	print(target)
#方法二:
batch_iterator=iter(data_loader)
img,target=next(batch_iterator)

data_loader获取到的数据为(batch,channels,height,width),复合图像网络层需要的数据定义。

网络定义

网络定义在ssd.py文件中。网络定义继承于torch.nn.Module。torch的优势之一是能够进行自动求导的过程。当定义了网络的正向的传播方向,会依照结构进行反向的传播过程,因此在网络的定义过程中只需要定义每个层以及对应层的forward功能。
作者将ssd网络定义分成两个部分。第一步获取各个模块的层,用list进行装填。第二步实现forward的串接

def build_ssd(phase, size=300, num_classes=21):
    if phase != "test" and phase != "train":
        print("ERROR: Phase: " + phase + " not recognized")
        return
    if size != 300:
        print("ERROR: You specified size " + repr(size) + ". However, " +
              "currently only SSD300 (size=300) is supported!")
        return
    base_, extras_, head_ = multibox(vgg(base[str(size)], 3),
                                     add_extras(extras[str(size)], 1024),
                                     mbox[str(size)], num_classes)		#获取三部分需要的卷积层
    return SSD(phase, size, base_, extras_, head_, num_classes)

multibox生成三部分网络层,分别用list装填(实际上,head_为两个module,分别用于计算类别和位置偏移量)
base_为基础网络部分可更换,extras_为基础网络后降采样的部分,head_为类别、位置偏移量计算的卷积部分。

super() 函数是用于调用父类(超类)的一个方法

已标记关键词 清除标记
©️2020 CSDN 皮肤主题: 大白 设计师:CSDN官方博客 返回首页