【Augmentation Zoo】RetinaNet + VOC + KITTI的数据预处理-pytorch版

整合前段时间看的数据增强方法,并测试其在VOC和KITTI数据上的效果。我的工作是完成了对VOC和KITTI数据的预处理,RetinaNet的模型代码来自pytorch-retinanet

该项目github仓库在:https://github.com/zzl-pointcloud/Data_Augmentation_Zoo_for_Object_Detection

 

目录

一、VOC数据预处理

二、KITTI数据预处理

三、Resizer类和collater()类

1. Resizer类

2. Collater类


整个代码的处理逻辑是:

  1. 继承torch.Dataset类定义新的数据集类,如VocDatasets类,KittiDatasets类,重写__getitem__(image_index)函数,其功能是,输入图片序号,返回一个sample = {'img': img, 'annots': annots}。类中其他函数均服务于__getitem__函数,如load_image(),load_annotations()等。
  2. 将transform传入Dataset中,transform.Compose([fun1(), fun2(), ...])。其中fun是object继承类,定义其中的__call__(),使得他们可以被作为函数使用。对每张图片顺序执行函数fun1(), fun2(), ...。这里的fun()就是数据增强方法的入口
  3.  sampler(从数据集中取样本的策略)处理后,数据集类转换为DataLoader对象。通过sampler中设置的yield,迭代返回每一次的数据。
  4. 至此数据预处理部分完成,送入模型开始训练。训练的数据流动我个人理解如下:

每个epoch将整个数据在模型中走一遍,至于取样本的策略由第3步的sampler决定。每个epoch中,数据集N = batch_size * iter_num,每个iter,前向传播、反向传播,验证集测试,保存模型等。

retinanet

optimizer = optim.Adam(retinanet.parameters(), lr=1e-5)


for epoch_num in range(epochs):
    for iter_num, data in enumerate(dataloader_train):
        # 前向传播,求解loss
        retinanet.train()
        classification_loss, regression_loss = retinanet([data['img'].float, data['annot']])  
        classification_loss = classification_loss.mean()
        regression_loss = regression_loss.mean()
        loss = classification_loss + regression_loss
        
        #反向传播,更新权重
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    """
    validation part

    """

"""
test part

"""
# 保存模型
torch.save(retinanet, "model_final.pt")

一、VOC数据预处理

class VocDataset(Dataset):
    def __init__(self,
                 root_dir,
                 image_set='train',         # train/val/test
                 years=['2007', '2012'],    # 默认2007+2012
                 transform=None,
                 keep_difficult=False
                 ):
        self.root_dir = root_dir
        self.years = years
        self.image_set = image_set
        self.transform = transform
        self.keep_difficult = keep_difficult

        self.categories = VOC_CLASSES

        self.name_2_label = dict(
            zip(self.categories, range(len(self.categories)))
        )
        self.label_2_name = {
            v: k
            for k, v in self.name_2_label.items()
        }
        self.ids = list()
        self.find_file_list()

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, image_index):

        img = self.load_image(image_index)
        annots = self.load_annotations(image_index)
        sample = {'img':img, 'annot':annots}
        if self.transform:
            sample = self.transform(sample)
        return sample

    def find_file_list(self):
        for year in self.years:
            if not (year == '2012' and self.image_set == 'test'):
                root_path = os.path.join(self.root_dir, 'VOC' + year)
                file_path = os.path.join(root_path, 'ImageSets', 'Main', self.image_set + '.txt')
                for line in open(file_path):
                    self.ids.append((root_path, line.strip()))

    def load_image(self, image_index):
        image_root_dir, img_idx = self.ids[image_index]
        image_path = os.path.join(image_root_dir,
                                 'JPEGImages', img_idx + '.jpg')
        img = cv2.imread(image_path)
        if len(img.shape) == 2:
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        return img.astype(np.float32)/255.0

    def load_annotations(self, image_index):
        image_root_dir, img_idx = self.ids[image_index]
        anna_path = os.path.join(image_root_dir,
                                'Annotations', img_idx + '.xml')
        annotations = []
        target = ET.parse(anna_path).getroot()
        for obj in target.iter("object"):
            difficult = int(obj.find('difficult').text) == 1
            if not self.keep_difficult and difficult:
                continue
            bbox = obj.find('bndbox')

            pts = ['xmin', 'ymin', 'xmax', 'ymax']

            bndbox = []
            for pt in pts:
                cut_pt = bbox.find(pt).text
                bndbox.append(float(cut_pt))
            name = obj.find('name').text.lower().strip()
            label = self.name_2_label[name]
            bndbox.append(label)
            annotations += [bndbox]
        annotations = np.array(annotations)

        return annotations

    def label_to_name(self, voc_label):
        return self.label_2_name[voc_label]

    def name_to_label(self, voc_name):
        return self.name_2_label[voc_name]

    def image_aspect_ratio(self, image_index):
        image_root_dir, img_idx = self.ids[image_index]
        image_path = os.path.join(image_root_dir,
                                  'JPEGImages', img_idx + '.jpg')
        img = cv2.imread(image_path)
        return float(img.shape[1] / float(img.shape[0]))

    def num_classes(self):
        return 20

二、KITTI数据预处理

对KITTI的数据预处理代码上与VOC相似,但在初始化KittiDataset类之前,需要先将KITTI数据集人工划分为训练/验证集,并生成类似于VOC中的train.txt和val.txt文件。因此我又实现了SplitKittiDataset类(在tools.py中),大概思路是:

1. 获得文件名list,及len

2. 用range(len)生成一个index,打乱后,按划分比例取train_index和val_index,然后从list中取对应的文件名

3. 保存到txt文件中。

class KittiDataset(Dataset):
    def __init__(self,
                 root_dir,
                 sets,
                 transform=None,
                 keep_difficult=False
                 ):
        self.root_dir = root_dir
        self.sets = sets
        self.transform = transform
        self.keep_difficult = keep_difficult

        self.categories = KITTI_CLASSES

        self.name_2_label = dict(
            zip(self.categories, range(len(self.categories)))
        )
        self.label_2_name = {
            v: k
            for k, v in self.name_2_label.items()
        }
        self.ids = list()
        self.find_file_list()

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, image_index):
        img = self.load_image(image_index)
        annot = self.load_annotations(image_index)
        sample = {'img':img, 'annot':annot}
        if self.transform:
            sample = self.transform(sample)
        return sample

    def find_file_list(self):
        file_path = os.path.join(self.root_dir, self.sets + '.txt')
        for line in open(file_path):
            self.ids.append(line.strip())

    def load_image(self, image_index):
        img_idx = self.ids[image_index]
        image_path = os.path.join(self.root_dir,
                                 'image_2', img_idx + '.png')
        img = cv2.imread(image_path)
        if len(img.shape) == 2:
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        return img.astype(np.float32)/255.0

    def load_annotations(self, image_index):
        img_idx = self.ids[image_index]
        anna_path = os.path.join(self.root_dir,
                                'label_2', img_idx + '.txt')
        annotations = []
        with open(anna_path) as file:
            lines = file.readlines()
            for line in lines:
                items = line.split(" ")
                name = items[0].lower().strip()
                if name == 'dontcare':
                    continue
                else:
                    bndbox = [float(items[i+4]) for i in range(4)]
                    label = self.name_2_label[name]
                    bndbox.append(int(label))
                annotations.append(bndbox)
        annotations = np.array(annotations)
        return annotations

    def label_to_name(self, voc_label):
        return self.label_2_name[voc_label]

    def name_to_label(self, voc_name):
        return self.name_2_label[voc_name]

    def image_aspect_ratio(self, image_index):
        img_idx = self.ids[image_index]
        image_path = os.path.join(self.root_dir,
                                  'image_2', img_idx + '.png')
        img = cv2.imread(image_path)
        return float(img.shape[1] / float(img.shape[0]))

    def num_classes(self):
        return 8

三、Resizer类和collater()类

分别用于将图片修改为限定大小和对齐。

1. Resizer类

设置短边上限和长边上限,如608/1024
scale = 短边上限 / 短边
if 长边 * scale > 长边上限:
    scale = 长边上限 / 长边
resized_image = cv2.resize(image, 长边 * scale, 短边 * scale)

将resized_image长宽填充为32的倍数

2. Collater类

图片填充为数据集最长宽和最长高(如:长边上限 * 长边上限),图片从左上角对齐,其余部分填充为0。

annots填充则是以sample为单位,都扩充到最多目标数,其余填充-1

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值