AOT源码解析1

论文阅读

代码分析

1静态数据处理

视频目标分割中的静态数据处理,通常是将单帧静态图像按照旋转、平移、裁剪、缩放、翻转等数据增强操作获得一系列不同的图像,再将这些图像组合在一起当成一段训练视频。除此之外,部分论文还会选择将生成得到的两个图像视频合并重叠在一起,得到一个更为复杂的新训练数据,以此来提升模型的鲁棒性。

1.1引入包

#用于从 __future__ 模块中导入特定的特性,使得这些特性在当前的 Python 版本中可用,即使这些特性在更高版本的 Python 中才被默认引入。
from __future__ import division
import os
#允许查找全文文件
from glob import glob
#用于编码和解码json数据
import json
import random
import cv2
from PIL import Image

import numpy as np
import torch
#Dataset 是一个抽象类,用于定义自定义数据集的接口。通过继承 torch.utils.data.Dataset 并实现特定的方法
from torch.utils.data import Dataset
#包含多种类型的图像变换:缩放、裁剪、归一化等
import torchvision.transforms as TF

import dataloaders.image_transforms as IT

#在 OpenCV 中,某些操作可能会使用多线程来加速处理。cv2.setNumThreads() 函数允许你指定 OpenCV 应该使用多少线程来执行这些操作。
#0表示 OpenCV 应该使用所有可用的硬件线程,它会自动选择最佳的线程数
cv2.setNumThreads(0)

1.2 继承Dataset类

"""
===============================================Dataset用法==============================================================
1、Dataset中常用的特定方法:__init__进行初始化,__len__得到数据数量,__getitem__获取数据。
2、一般流程为:先对dataset进行初始化,得到数据地址、数据名列表等。再使用数据名列表得到数据数量,并将该数量传给len。
3、getitem从len得到索引,索引从0~len-1.然后通过索引获取数据。然后使用Dataloader读取Dataset传出的数据。
------------------------------------------------------------------------------------------------------------------------
"""

1.3 数据初始化

#===========================================将静态图像转换成生成视频,用于预训练===============================================
class StaticTrain(Dataset):
    def __init__(self,
                 root,
                 output_size,
                 seq_len=5,
                 max_obj_n=10,
                 dynamic_merge=True,
                 merge_prob=1.0,
                 aug_type='v1'):  #aug_type是两种不同的图像增强方式
        self.root = root  #数据集根目录,根目录底下为imge、annotation,这两个下面是各个数据集的数据
        self.clip_n = seq_len #序列长度
        self.output_size = output_size   #输出的图像大小
        self.max_obj_n = max_obj_n       #图像中的最大对象数量,用于生成one-hot编码,为每个对象生成唯一标识符

        self.dynamic_merge = dynamic_merge
        self.merge_prob = merge_prob  #合并概率

        self.img_list = list()
        self.mask_list = list()


#==================================================获取数据列表===========================================================
#1、获取数据集名称,所有数据都按特定格式保存
#2、获取所有图像数据的名称,并存为list并确保所有img都有对应的mask

        dataset_list = list() #用于保存使用到的数据集名称
        lines = ['COCO', 'ECSSD', 'MSRA10K', 'PASCAL-S', 'PASCALVOC2012']#可能会用到的数据集,训练自己的数据时,在里面加入自己的数据名
        for line in lines:
            dataset_name = line.strip()#移除字符串开头和末尾的空白字符

            img_dir = os.path.join(root, 'JPEGImages', dataset_name)
            mask_dir = os.path.join(root, 'Annotations', dataset_name)

            img_list = sorted(glob(os.path.join(img_dir, '*.jpg'))) + \
                sorted(glob(os.path.join(img_dir, '*.png')))#搜索img_dir文件夹下的所有.jpg和.png结尾的文件
            mask_list = sorted(glob(os.path.join(mask_dir, '*.png')))#同上

            if len(img_list) > 0:#确保存在数据
                if len(img_list) == len(mask_list):#确保所有img都有对应的mask
                    dataset_list.append(dataset_name)
                    self.img_list += img_list
                    self.mask_list += mask_list
                    print(f'\t{dataset_name}: {len(img_list)} imgs.')
                else:
                    print(
                        f'\tPreTrain dataset {dataset_name} has {len(img_list)} imgs and {len(mask_list)} annots. Not match! Skip.'
                    )
            else:
                print(
                    f'\tPreTrain dataset {dataset_name} doesn\'t exist. Skip.')

        print(
            f'{len(self.img_list)} imgs are used for PreTrain. They are from {dataset_list}.'
        )

#====================================================初始化数据增强=============================================================
#1、aug_type定义了增强模式。设为v1时,轻微改变图像的亮度、饱和度等;设为v2时,随机较大的改变图像亮度、饱和度,并随机灰度化图像,随机添加噪声;设为其他值时报错
#2、随机旋转、平移、缩放、剪切、插值、填充颜色
#3、随机裁剪和缩放
#4、将图像转换为张量
#5、将标签转换为one-hot编码
#6、进行归一化

        self.aug_type = aug_type

        self.pre_random_horizontal_flip = IT.RandomHorizontalFlip(0.5)

        self.random_horizontal_flip = IT.RandomHorizontalFlip(0.3)

        if self.aug_type == 'v1':
            self.color_jitter = TF.ColorJitter(0.1, 0.1, 0.1, 0.03)
        elif self.aug_type == 'v2':
            self.color_jitter = TF.RandomApply(
                [TF.ColorJitter(0.4, 0.4, 0.2, 0.1)], p=0.8)
            self.gray_scale = TF.RandomGrayscale(p=0.2)
            self.blur = TF.RandomApply([IT.GaussianBlur([.1, 2.])], p=0.3)
        else:
            assert NotImplementedError

        self.random_affine = IT.RandomAffine(degrees=20,
                                             translate=(0.1, 0.1),
                                             scale=(0.9, 1.1),
                                             shear=10,
                                             resample=Image.BICUBIC,
                                             fillcolor=(124, 116, 104))
        #scale表示裁剪的区域为80%~100%,ratio: 裁剪区域的宽高比范围,图像缩放时使用的插值方法。Image.BICUBIC 是双三次插值,适用于缩小和放大图像。
        base_ratio = float(output_size[1]) / output_size[0]
        self.random_resize_crop = IT.RandomResizedCrop(
            output_size, (0.8, 1),
            ratio=(base_ratio * 3. / 4., base_ratio * 4. / 3.),
            interpolation=Image.BICUBIC)
        self.to_tensor = TF.ToTensor()
        #max_obj_n代表数据集中最大的对象数量,用于确定one-hot编码向量的长度,使每个对象都有唯一编码
        self.to_onehot = IT.ToOnehot(max_obj_n, shuffle=True)
        self.normalize = TF.Normalize((0.485, 0.456, 0.406),
                                      (0.229, 0.224, 0.225))

1.4 获取数据长度

    def __len__(self):
        return len(self.img_list)

1.5 获取数据

  • 1 总体框架
#========================================获取生成视频,并将图片样本进行合并,生成新的训练数据=====================================
#1、通过seq_num定义生成的视频数据长度
#2、通过旋转、裁剪等数据增强操作生成新视频
#3、若dynamic_merge=TRUE和merge_prob=1,则进行样本合并,生成新的数据
    def __getitem__(self, idx):
        sample1 = self.sample_sequence(idx) #获取生成的视频
        if self.dynamic_merge and (sample1['meta']['obj_num'] == 0
                                   or random.random() < self.merge_prob):
            #选取随机索引,并且使得随机到的索引不等于当前索引
            rand_idx = np.random.randint(len(self.img_list))
            while (rand_idx == idx):
                rand_idx = np.random.randint(len(self.img_list))

            #获取第二个样本
            sample2 = self.sample_sequence(rand_idx)

            sample = self.merge_sample(sample1, sample2)
        else:
            sample = sample1

        return sample

    def merge_sample(self, sample1, sample2, min_obj_pixels=100):
        return _merge_sample(sample1, sample2, min_obj_pixels, self.max_obj_n)
  • 2 生成视频
    def load_image_in_PIL(self, path, mode='RGB'):
        img = Image.open(path)
        img.load()  # Very important for loading large image,调用 load 方法将图像数据加载到内存中
        return img.convert(mode)

    def sample_sequence(self, idx):
        #读取图像,并读取为PIL格式
        img_pil = self.load_image_in_PIL(self.img_list[idx], 'RGB')
        mask_pil = self.load_image_in_PIL(self.mask_list[idx], 'P')

        frames = []
        masks = []

        #随机水平翻转
        img_pil, mask_pil = self.pre_random_horizontal_flip(img_pil, mask_pil)
        # img_pil, mask_pil = self.pre_random_vertical_flip(img_pil, mask_pil)

        #===============================================通过循环生成视频===================================================
        for i in range(self.clip_n):
            img, mask = img_pil, mask_pil

            if i > 0:
                img, mask = self.random_horizontal_flip(img, mask)
                img, mask = self.random_affine(img, mask)

            img = self.color_jitter(img)

            img, mask = self.random_resize_crop(img, mask)

            if self.aug_type == 'v2':
                img = self.gray_scale(img)
                img = self.blur(img)

            #---------------------掩码处理--------------------------
            #将mask转为numpy数组
            mask = np.array(mask, np.uint8)
            if i == 0:
                mask, obj_list = self.to_onehot(mask)
                obj_num = len(obj_list)
            else:
                #返回的maskshape为(max_obj_n+1,height,weight),其中0通道为背景掩码,剩下每一个通道代表一个对象的mask
                mask, _ = self.to_onehot(mask, obj_list)
            mask = torch.argmax(mask, dim=0, keepdim=True)

            frames.append(self.normalize(self.to_tensor(img)))#frames存储原始视频帧
            masks.append(mask) #masks存储模板值

        sample = {
            'ref_img': frames[0], #参考图像,即序列中的第一个图像
            'prev_img': frames[1],
            'curr_img': frames[2:],
            'ref_label': masks[0],
            'prev_label': masks[1],
            'curr_label': masks[2:]
        }
        sample['meta'] = {
            'seq_name': self.img_list[idx],
            'frame_num': 1,
            'obj_num': obj_num
        }

        return sample
  • 3 样本合并
def _get_images(sample):
    return [sample['ref_img'], sample['prev_img']] + sample['curr_img']

def _get_labels(sample):
    return [sample['ref_label'], sample['prev_label']] + sample['curr_label']


#===============================================将两个样本融合为一个========================================================
def _merge_sample(sample1, sample2, min_obj_pixels=100, max_obj_n=10):

    #提取所有的图像和掩码
    sample1_images = _get_images(sample1)
    sample2_images = _get_images(sample2)

    sample1_labels = _get_labels(sample1)
    sample2_labels = _get_labels(sample2)

    #obj_idx: 一个用于索引对象的张量,范围从 0 到 max_obj_n * 2。
    #selected_idx 和 selected_obj: 用于存储选择的对象索引和对象本身。
    obj_idx = torch.arange(0, max_obj_n * 2 + 1).view(max_obj_n * 2 + 1, 1, 1)
    selected_idx = None
    selected_obj = None

    all_img = []
    all_mask = []
    #========================通过合并不同样本,创建新的训练样本======================
    for idx, (s1_img, s2_img, s1_label, s2_label) in enumerate(
            zip(sample1_images, sample2_images, sample1_labels,
                sample2_labels)):
        s2_fg = (s2_label > 0).float()
        s2_bg = 1 - s2_fg
        merged_img = s1_img * s2_bg + s2_img * s2_fg
        merged_mask = s1_label * s2_bg.long() + (
            (s2_label + max_obj_n) * s2_fg.long())
        merged_mask = (merged_mask == obj_idx).float()
        if idx == 0:
            after_merge_pixels = merged_mask.sum(dim=(1, 2), keepdim=True)
            selected_idx = after_merge_pixels > min_obj_pixels
            selected_idx[0] = True
            obj_num = selected_idx.sum().int().item() - 1
            selected_idx = selected_idx.expand(-1,
                                               s1_label.size()[1],
                                               s1_label.size()[2])
            if obj_num > max_obj_n:
                selected_obj = list(range(1, obj_num + 1))
                random.shuffle(selected_obj)
                selected_obj = [0] + selected_obj[:max_obj_n]

        merged_mask = merged_mask[selected_idx].view(obj_num + 1,
                                                     s1_label.size()[1],
                                                     s1_label.size()[2])
        if obj_num > max_obj_n:
            merged_mask = merged_mask[selected_obj]
        merged_mask[0] += 0.1
        merged_mask = torch.argmax(merged_mask, dim=0, keepdim=True).long()

        all_img.append(merged_img)
        all_mask.append(merged_mask)

    sample = {
        'ref_img': all_img[0],
        'prev_img': all_img[1],
        'curr_img': all_img[2:],
        'ref_label': all_mask[0],
        'prev_label': all_mask[1],
        'curr_label': all_mask[2:]
    }
    sample['meta'] = sample1['meta']
    sample['meta']['obj_num'] = min(obj_num, max_obj_n)
    return sample

2 视频数据处理

这里的做法是在原始视频段上截取一段随机长度的子片段,然后再将该子片段随机分成4段更小的片段,小片段的长度不一。将随机得到的小片段进行数据增强和数据合并(与1 静态数据处理中的一致)后,作为训练数据训练model。ref_frame是参考帧,半监督视频分割网络依据参考帧进行分割;curr_frame是当前帧,用于分割预测;pre_frame是当前帧的前一帧。这里值得注意的是:参考帧要求前景mask不能过小,同时必须包含当前帧中的所有对象。

2.1 数据初始化-父类VOSTrain

#========================================从视频序列中采样序列帧,并进行数据增强和预处理==========================================
class VOSTrain(Dataset):
    def __init__(self,
                 image_root,
                 label_root,
                 imglistdic,
                 transform=None,
                 rgb=True,
                 repeat_time=1,
                 rand_gap=3,
                 seq_len=5,
                 rand_reverse=True,
                 dynamic_merge=True,
                 enable_prev_frame=False,
                 merge_prob=0.3,
                 max_obj_n=10):
        self.image_root = image_root #图像根目录
        self.label_root = label_root #label根目录
        self.rand_gap = rand_gap #随机采样视频间隔
        self.seq_len = seq_len #视频段长度
        self.rand_reverse = rand_reverse #随机逆序遍历列表,原代码的概率是0.5
        self.repeat_time = repeat_time #从视频段中重复采取视频段的次数
        self.transform = transform #图像变换
        self.dynamic_merge = dynamic_merge #动态合并
        self.merge_prob = merge_prob #动态合并概率
        self.enable_prev_frame = enable_prev_frame #使用前一帧
        self.max_obj_n = max_obj_n #图像的最大对象数量
        self.rgb = rgb #rgb图像
        self.imglistdic = imglistdic #里面包含各个视频名称,每个视频名称对应的是相应的图像名列表和label名列表
        self.seqs = list(self.imglistdic.keys()) #包含的是各个视频名称
        print('Video Num: {} X {}'.format(len(self.seqs), self.repeat_time))

2.2 数据初始化-子类DAVIS2017_Train

  • DAVIS数据集格式
    在这里插入图片描述
    在这里插入图片描述
class DAVIS2017_Train(VOSTrain):
    def __init__(self,
                 split=['train'],
                 root='./DAVIS',
                 transform=None,
                 rgb=True,
                 repeat_time=1,
                 full_resolution=True,
                 year=2017,
                 rand_gap=3,
                 seq_len=5,
                 rand_reverse=True,
                 dynamic_merge=True,
                 enable_prev_frame=False,
                 max_obj_n=10,
                 merge_prob=0.3):
        #选择不同分辨率,这个主要是DAVIS数据集自带的分辨率类别。
        #具体DAVIS结构见CSDN图
        if full_resolution:
            resolution = 'Full-Resolution'
            if not os.path.exists(os.path.join(root, 'JPEGImages',
                                               resolution)):
                print('No Full-Resolution, use 480p instead.')
                resolution = '480p'
        else:
            resolution = '480p'
        image_root = os.path.join(root, 'JPEGImages', resolution)
        label_root = os.path.join(root, 'Annotations', resolution)
        seq_names = [] #存储的是视频名称
        for spt in split:
            #这里得到的地址是:root/ImageSets/2017/train.txt
            with open(os.path.join(root, 'ImageSets', str(year),
                                   spt + '.txt')) as f:
                seqs_tmp = f.readlines()
            #创建一个新的列表,其中包含seqs_tmp,列表中每个元素的副本,但已经去除了每个元素两端的空白字符。
            seqs_tmp = list(map(lambda elem: elem.strip(), seqs_tmp))
            seq_names.extend(seqs_tmp)
        imglistdic = {}
        for seq_name in seq_names:
            #os.path.join(image_root, seq_name)存储的是每个视频文件夹的地址
            #images和labels存储的是各个图像名列表
            images = list(
                np.sort(os.listdir(os.path.join(image_root, seq_name))))
            labels = list(
                np.sort(os.listdir(os.path.join(label_root, seq_name))))
            imglistdic[seq_name] = (images, labels) #存储的是每个视频名称类别对应一个图像路径和一个标签路径

        super(DAVIS2017_Train, self).__init__(image_root,
                                              label_root,
                                              imglistdic,
                                              transform,
                                              rgb,
                                              repeat_time,
                                              rand_gap,
                                              seq_len,
                                              rand_reverse,
                                              dynamic_merge,
                                              enable_prev_frame,
                                              merge_prob=merge_prob,
                                              max_obj_n=max_obj_n)

2.3 获得数据长度

    def __len__(self):
        return int(len(self.seqs) * self.repeat_time)

2.4 获得数据

  • 主体框架
    def __getitem__(self, idx):
        #获得样本
        sample1 = self.sample_sequence(idx)

        #进行样本合并
        if self.dynamic_merge and (sample1['meta']['obj_num'] == 0
                                   or random.random() < self.merge_prob):
            rand_idx = np.random.randint(len(self.seqs))
            while (rand_idx == (idx % len(self.seqs))):#确保随机生成的索引与当前索引不重复
                rand_idx = np.random.randint(len(self.seqs))

            sample2 = self.sample_sequence(rand_idx)

            sample = self.merge_sample(sample1, sample2)
        else:
            sample = sample1

        return sample

    def merge_sample(self, sample1, sample2, min_obj_pixels=100):
        return _merge_sample(sample1, sample2, min_obj_pixels, self.max_obj_n)
  • 获取样本
    def sample_sequence(self, idx):
        #通过取余的方法,选取不同的视频类别,并读取数据
        idx = idx % len(self.seqs)
        seqname = self.seqs[idx]
        imagelist, lablist = self.imglistdic[seqname]
        frame_num = len(imagelist) #视频中包含的帧数

        #随机逆序遍历列表
        if self.rand_reverse:
            imagelist, lablist = self.reverse_seq(imagelist, lablist)

        is_consistent = False #一致性开关,这个用于控制当前帧和前一帧的对象在参考帧中都存在
        max_try = 5 #最大循环次数
        try_step = 0 #当前循环次数
        #只要没有找到一致的帧序列且尝试次数小于最大尝试次数,就继续循环。
        #为的是通过随机采样和一致性检查,从视频序列中选取一组有意义的帧,这些帧将被用于后续的视频对象分割任务。
        while (is_consistent is False and try_step < max_try):
            try_step += 1

            # ==============================================随机生成当前间隔===============================================
            #在这里seq_len为5,get_curr_gaps会随机生成4(seq_len-1)次当前gap,这些gap是1~4中随机生成。total_gap是所有当前gap的和,相当于最终训练时获得的视频长度
            curr_gaps, total_gap = self.get_curr_gaps(self.seq_len - 1)

            #==================如果 self.enable_prev_frame 为真,表示允许随机采样前一帧。====================
            if self.enable_prev_frame:  # prev frame is randomly sampled
                # get prev frame
                #这里的做法是:视频总长度-训练时获得的视频长度,然后在这个范围内随机选取一个值。这样做的原因就是为了当最终裁取的视频段不会逾越过原视频
                prev_index = self.get_prev_index(lablist, total_gap)#根据总间隔获取前一帧的索引。
                #获取图像和标签。
                prev_image, prev_label = self.get_image_label(
                    seqname, imagelist, lablist, prev_index)
                #提取出唯一的对象标识符
                prev_objs = list(np.unique(prev_label))

                # get curr frames
                #根据前一帧索引和间隔生成当前帧的索引列表。即以前一帧为star,以star+total_gap结尾,截取一段视频段。然后这段视频段被分成seq_len-1个小段
                curr_indices = self.get_curr_indices(lablist, prev_index,
                                                     curr_gaps)
                #遍历 curr_indices,获取当前帧的图像和标签,并收集所有对象
                curr_images, curr_labels, curr_objs = [], [], []
                for curr_index in curr_indices:
                    #获取图像和标签
                    curr_image, curr_label = self.get_image_label(
                        seqname, imagelist, lablist, curr_index)
                    c_objs = list(np.unique(curr_label))
                    curr_images.append(curr_image)
                    curr_labels.append(curr_label)
                    curr_objs.extend(c_objs)

                #收集前一帧和当前帧的所有对象
                objs = list(np.unique(prev_objs + curr_objs))

                start_index = prev_index
                end_index = max(curr_indices)
                # get ref frame
                _try_step = 0
                ref_index = self.get_ref_index_v2(seqname, lablist)#参考帧随机采样
                #如果参考帧索引在前一帧和当前帧索引范围内,重新生成,直到找到一个合适的参考帧。
                while (ref_index > start_index and ref_index <= end_index
                       and _try_step < max_try):
                    _try_step += 1
                    ref_index = self.get_ref_index_v2(seqname, lablist)#随机获取参考帧索引,这里要求参考帧的前景mask不能过小
                ref_image, ref_label = self.get_image_label(
                    seqname, imagelist, lablist, ref_index)
                ref_objs = list(np.unique(ref_label))
            else:  # prev frame is next to ref frame,如果 self.enable_prev_frame 为假,表示前一帧是参考帧的下一帧。
                # get ref frame,直接使用参考帧索引获取当前帧的索引列表和图像标签。
                ref_index = self.get_ref_index_v2(seqname, lablist)

                ref_image, ref_label = self.get_image_label(
                    seqname, imagelist, lablist, ref_index)
                ref_objs = list(np.unique(ref_label))

                # get curr frames
                curr_indices = self.get_curr_indices(lablist, ref_index,
                                                     curr_gaps)
                curr_images, curr_labels, curr_objs = [], [], []
                for curr_index in curr_indices:
                    curr_image, curr_label = self.get_image_label(
                        seqname, imagelist, lablist, curr_index)
                    c_objs = list(np.unique(curr_label))
                    curr_images.append(curr_image)
                    curr_labels.append(curr_label)
                    curr_objs.extend(c_objs)

                objs = list(np.unique(curr_objs))
                prev_image, prev_label = curr_images[0], curr_labels[0]
                curr_images, curr_labels = curr_images[1:], curr_labels[1:]

            is_consistent = True#假设帧序列是一致的。
            #遍历所有对象,检查它们是否在参考帧中也存在。如果有任何对象在参考帧中不存在,则将 is_consistent 设置为假,并跳出循环。
            for obj in objs:
                if obj == 0:
                    continue
                if obj not in ref_objs:
                    is_consistent = False
                    break

        # get meta info
        obj_num = list(np.sort(ref_objs))[-1]

        sample = {
            'ref_img': ref_image,
            'prev_img': prev_image,
            'curr_img': curr_images,
            'ref_label': ref_label,
            'prev_label': prev_label,
            'curr_label': curr_labels
        }
        sample['meta'] = {
            'seq_name': seqname,
            'frame_num': frame_num,
            'obj_num': obj_num
        }

        if self.transform is not None:
            sample = self.transform(sample)

        return sample

  • get_curr_gaps
    在这里插入图片描述
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值