单目标跟踪:数据集处理

日萌社

人工智能AI:Keras PyTorch MXNet TensorFlow PaddlePaddle 深度学习实战(不定时更新)


CNN:RCNN、SPPNet、Fast RCNN、Faster RCNN、YOLO V1 V2 V3、SSD、FCN、SegNet、U-Net、DeepLab V1 V2 V3、Mask RCNN

单目标跟踪SiamMask:特定目标车辆追踪 part1

单目标跟踪SiamMask:特定目标车辆追踪 part2

单目标跟踪 Siamese系列网络:SiamFC、SiamRPN、one-shot跟踪、one-shotting单样本学习、DaSiamRPN、SiamRPN++、SiamMask

单目标跟踪:跟踪效果

单目标跟踪:数据集处理

单目标跟踪:模型搭建

单目标跟踪:模型训练

单目标跟踪:模型测试


1.3 DataSet

学习目标:

  • 了解图像增强的方法
  • 知道正负样本的处理方法

数据处理的程序在文件夹dataset中的siam_mask_dataset.py中,该文件中主要包含四个类和一个方法:

1.DataSets

DataSets是进行数据获取的出口。

4.1 初始化

初始化主要是参数的设置,主要包括以下内容:

代码如下:

class DataSets(Dataset):
    def __init__(self, cfg, anchor_cfg, num_epoch=1):
        super(DataSets, self).__init__()
        global logger
        logger = logging.getLogger('global')

        # anchors
        self.anchors = Anchors(anchor_cfg)
        # size
        self.template_size = 127
        self.origin_size = 127
        self.search_size = 255
        self.size = 17
        self.base_size = 0
        self.crop_size = 0
        # 根据配置文件更新参数
        if 'template_size' in cfg:
            self.template_size = cfg['template_size']
        if 'origin_size' in cfg:
            self.origin_size = cfg['origin_size']
        if 'search_size' in cfg:
            self.search_size = cfg['search_size']
        if 'base_size' in cfg:
            self.base_size = cfg['base_size']
        if 'size' in cfg:
            self.size = cfg['size']

        if (self.search_size - self.template_size) / self.anchors.stride + 1 + self.base_size != self.size:
            raise Exception("size not match!")  # TODO: calculate size online
        if 'crop_size' in cfg:
            self.crop_size = cfg['crop_size']
        self.template_small = False
        if 'template_small' in cfg and cfg['template_small']:
            self.template_small = True
        # 生成anchor
        self.anchors.generate_all_anchors(im_c=self.search_size//2, size=self.size)
        if 'anchor_target' not in cfg:
            cfg['anchor_target'] = {}
        # 生成anchor的信息:cls,reg,mask
        self.anchor_target = AnchorTargetLayer(cfg['anchor_target'])

        # data sets
        if 'datasets' not in cfg:
            raise(Exception('DataSet need "{}"'.format('datasets')))

        self.all_data = []
        start = 0
        self.num = 0
        for name in cfg['datasets']:
            dataset = cfg['datasets'][name]
            dataset['mark'] = name
            dataset['start'] = start
            # 加载数据
            dataset = SubDataSet(dataset)
            dataset.log()
            self.all_data.append(dataset)
            # 数据数量
            start += dataset.num  # real video number
            # 打乱的数据数量
            self.num += dataset.num_use  # the number used for subset shuffle

        # 数据增强data augmentation
        aug_cfg = cfg['augmentation']
        self.template_aug = Augmentation(aug_cfg['template'])
        self.search_aug = Augmentation(aug_cfg['search'])
        self.gray = aug_cfg['gray']
        self.neg = aug_cfg['neg']
        self.inner_neg = 0 if 'inner_neg' not in aug_cfg else aug_cfg['inner_neg']

        self.pick = None  # list to save id for each img
        if 'num' in cfg:  # number used in training for all dataset
            self.num = int(cfg['num'])
        self.num *= num_epoch
        self.shuffle()

        self.infos = {
                'template': self.template_size,
                'search': self.search_size,
                'template_small': self.template_small,
                'gray': self.gray,
                'neg': self.neg,
                'inner_neg': self.inner_neg,
                'crop_size': self.crop_size,
                'anchor_target': self.anchor_target.__dict__,
                'num': self.num // num_epoch
                }
        logger.info('dataset informations: \n{}'.format(json.dumps(self.infos, indent=4)))

4.2.辅助函数

辅助函数主要图像读取,数据查找等。

图像读取:

 def imread(self, path):
        # 数据读取
        img = cv2.imread(path)
        if self.origin_size == self.template_size:
            # 返回图像
            return img, 1.0

        def map_size(exe, size):
            return int(round(((exe + 1) / (self.origin_size + 1) * (size+1) - 1)))
        # 尺寸调整
        nsize = map_size(self.template_size, img.shape[1])
        # 调整图像大小
        img = cv2.resize(img, (nsize, nsize))
        # 返回图像和缩放比例
        return img, nsize / img.shape[1]

数据查找:

    def find_dataset(self, index):
        "查找数据"
        for dataset in self.all_data:
            if dataset.start + dataset.num > index:
                # 返回索引范围内的数据
                return dataset, index - dataset.start

数据打乱:

 def shuffle(self):
        "打乱"
        pick = []
        m = 0
        # 获取数据
        while m < self.num:
            p = []
            for subset in self.all_data:
                sub_p = subset.shuffle()
                p += sub_p
            # 打乱数据
            sample_random.shuffle(p)
            # 将打乱的结果进行拼接
            pick += p
            m = len(pick)
        # 将打乱的结果赋值给pick
        self.pick = pick
        logger.info("shuffle done!")
        logger.info("dataset length {}".format(self.num))

4.3.数据构建

数据构建完成了训练数据的构建,通过getItems完成,主要流程是:

实现代码:

    def __getitem__(self, index, debug=False):
        # 在打乱的结果中找到索引
        index = self.pick[index]
        # 查找得到数据
        dataset, index = self.find_dataset(index)
        # 灰度图
        gray = self.gray and self.gray > random.random()
        # 负样本
        neg = self.neg and self.neg > random.random()
        # 负样本
        if neg:
            # 获取template
            template = dataset.get_random_target(index)
            # 根据设置,从数据生成负样本或随机选择负样本
            if self.inner_neg and self.inner_neg > random.random():
                search = dataset.get_random_target()
            else:
                search = random.choice(self.all_data).get_random_target()
        else:
            # 获得正样本对
            template, search = dataset.get_positive_pair(index)
        # 裁剪图像的中央大小为size的部分
        def center_crop(img, size):
            # 获取图像的形状
            shape = img.shape[1]
            # 若为size,则直接返回
            if shape == size: return img
            # 否则,裁剪中央位置为size大小的图像
            c = shape // 2
            l = c - size // 2
            r = c + size // 2 + 1
            return img[l:r, l:r]
        # 读取模板图像
        template_image, scale_z = self.imread(template[0])
        # 若设置为小模板时,则从模板图像中进行裁剪
        if self.template_small:
            template_image = center_crop(template_image, self.template_size)
        # 读取待搜索图像
        search_image, scale_x = self.imread(search[0])
        # 若存在掩膜并且不是负样本数据
        if dataset.has_mask and not neg:
            # 读取掩膜数据
            search_mask = (cv2.imread(search[2], 0) > 0).astype(np.float32)
        else:
            # 掩膜数据用全零数组替代
            search_mask = np.zeros(search_image.shape[:2], dtype=np.float32)
        # 若裁剪size大于0,对搜索图像和掩膜进行裁剪
        if self.crop_size > 0:
            search_image = center_crop(search_image, self.crop_size)
            search_mask = center_crop(search_mask, self.crop_size)
        # 根据图像大小生成bbox,shape是模板图像中bbox的形状
        def toBBox(image, shape):
            # 图像的大小
            imh, imw = image.shape[:2]
            # 获取shape的宽高
            if len(shape) == 4:
                w, h = shape[2]-shape[0], shape[3]-shape[1]
            else:
                w, h = shape
            # 扩展比例
            context_amount = 0.5
            # 模板尺寸
            exemplar_size = self.template_size  # 127
            # 获取宽高
            wc_z = w + context_amount * (w+h)
            hc_z = h + context_amount * (w+h)
            # 等效边长
            s_z = np.sqrt(wc_z * hc_z)
            # 比例
            scale_z = exemplar_size / s_z
            # 宽高
            w = w*scale_z
            h = h*scale_z
            # 中心点坐标
            cx, cy = imw//2, imh//2
            bbox = center2corner(Center(cx, cy, w, h))
            return bbox
        # 生成模板图像和待搜索图像中的bbox
        template_box = toBBox(template_image, template[1])
        search_box = toBBox(search_image, search[1])
        # 模板数据增强
        template, _, _ = self.template_aug(template_image, template_box, self.template_size, gray=gray)
        # 待搜索图像的数据增强
        search, bbox, mask = self.search_aug(search_image, search_box, self.search_size, gray=gray, mask=search_mask)

        # 生成anchor对应的信息
        cls, delta, delta_weight = self.anchor_target(self.anchors, bbox, self.size, neg)
        if dataset.has_mask and not neg:
            # 掩膜图像
            mask_weight = cls.max(axis=0, keepdims=True)
        else:
            mask_weight = np.zeros([1, cls.shape[1], cls.shape[2]], dtype=np.float32)
        # 模板和搜索图像
        template, search = map(lambda x: np.transpose(x, (2, 0, 1)).astype(np.float32), [template, search])
        # 掩膜结果
        mask = (np.expand_dims(mask, axis=0) > 0.5) * 2 - 1  # 1*H*W
        # 返回结果
        return template, search, cls, delta, delta_weight, np.array(bbox, np.float32), \
               np.array(mask, np.float32), np.array(mask_weight, np.float32)

总结:

  • 图像增强使用了图像裁剪,尺度变换,平移,模糊等方法
  • 我们使用anchors对应的数据作为正负样本完成训练数据的构建

  • 1
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

あずにゃん

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值