从零实现无监督光流pipline(1):数据dataset部分

前言

宝宝心里苦呀,调了几天的官方的tensorflow代码怎么都弄不通,后来一想,本科的毕业设计好像没有精度要求吧,我不如自己去搭建一个pipline,这样以后我也可以在这个基础上进行改进嘛。并且最理想的就是可以去copy下大佬写的代码嘛,但是很难受,很多都跑不通,终于我找到一个可以跑通的工程了:ARFlow,很感谢这个作者,让我在黑暗中看到了光,我的世界没有别人,你是未来的我给我的帮助嘛,感谢引领!

Here we go!

数据dataset基类

我们在构建不同数据dataset的时候,要先去定义一个基类,之后面对不同的数据集定义不同的dataset类,我们这里借助ARFlow的代码进行,下面是注释后的代码。

import imageio
import numpy as np
import random
from path import Path
from abc import abstractmethod, ABCMeta
from torch.utils.data import Dataset
from utils.flow_utils import load_flow

class ImgSeqDataset(Dataset, metaclass=ABCMeta):
    def __init__(self, root, n_frames, input_transform=None, co_transform=None,
                 target_transform=None, ap_transform=None):
        # 数据集路径
        self.root = Path(root)
        # 图像序列帧数,现在很多论文开始使用3帧了
        self.n_frames = n_frames
        # 对输入帧进行变换
        self.input_transform = input_transform
        # 协同变换
        self.co_transform = co_transform
        # 姿态变换
        self.ap_transform = ap_transform
        # 对目标进行变换,这里指光流和mask
        self.target_transform = target_transform
        # 通过collect_samples方法得到图像
        self.samples = self.collect_samples()

    @abstractmethod
    def collect_samples(self):
        # 抽象方法,子类需要实现用于收集样本信息的方法
        pass

    def _load_sample(self, s):
        # 加载单个样本的方法

        # 从样本字典中获取图像序列路径列表
        images = s['imgs']

        # 读取图像序列中的每一帧图像,self.root / p 表示构建完整的数据路径
        # 将每一帧图像加载为 NumPy 数组,并将数据类型转换为 float32
        images = [imageio.imread(self.root / p).astype(np.float32) for p in images]

        # 初始化目标字典
        target = {}

        # 如果样本中包含光流信息,则加载光流数据
        if 'flow' in s:
            target['flow'] = load_flow(self.root / s['flow'])

        # 如果样本中包含掩码信息,则加载并处理掩码数据
        if 'mask' in s:
            mask = imageio.imread(self.root / s['mask']).astype(np.float32) / 255.
            # 如果掩码数据是三维的,只选择第一个通道
            if len(mask.shape) == 3:
                mask = mask[:, :, 0]
            # 将掩码数据的通道维度扩展,变成 (H, W, 1) 的形状
            target['mask'] = np.expand_dims(mask, -1)

        # 返回加载的图像序列列表和目标字典
        return images, target

    def __len__(self):
        # 返回数据集中样本的数量
        return len(self.samples)

    def __getitem__(self, idx):
        # 获取指定索引处的样本
        images, target = self._load_sample(self.samples[idx])

        if self.co_transform is not None:
            # 如果定义了协同变换,则应用协同变换
            images, _ = self.co_transform(images, {})
        if self.input_transform is not None:
            # 如果定义了输入变换,则应用输入变换
            images = [self.input_transform(i) for i in images]
        data = {'img{}'.format(i + 1): p for i, p in enumerate(images)}

        if self.ap_transform is not None:
            # 如果定义了姿势估计变换,则应用姿势估计变换
            imgs_ph = self.ap_transform(
                [data['img{}'.format(i + 1)].clone() for i in range(self.n_frames)])
            for i in range(self.n_frames):
                data['img{}_ph'.format(i + 1)] = imgs_ph[i]

        if self.target_transform is not None:
            # 如果定义了目标变换,则应用目标变换
            for key in self.target_transform.keys():
                target[key] = self.target_transform[key](target[key])
        data['target'] = target
        return data

这里我们看到定义一个图像对的dataset,然后大多数都在注释中,这里说点细节或者需要知道的一些知识:

from abc import abstractmethod, ABCMeta

先说下这个包的作用吧,在给定的代码中,from abc import abstractmethod, ABCMeta 引入了 Python 标准库中的 abc 模块,其中包括两个重要的元素:abstractmethod 和 ABCMeta。

abstractmethod 是一个装饰器(decorator),用于声明一个方法是抽象方法。抽象方法是在基类中声明但没有提供具体实现的方法,留待子类去实现。在抽象类中,至少包含一个抽象方法。子类必须实现所有抽象方法,否则无法被实例化。

ABCMeta 是一个元类,用于创建抽象基类(Abstract Base Class,ABC)。元类是一种类的类,它定义了类的行为,如何创建类的实例以及类与其实例的关系。ABCMeta 是 Python 中用于创建抽象基类的特殊元类。通过使用 ABCMeta 元类,类可以被定义为抽象基类,并且可以包含抽象方法。
在给定的代码中,这两个元素被用于定义一个抽象基类 ImgSeqDataset。通过在类定义时使用 metaclass=ABCMeta,指定了该类的元类为 ABCMeta,从而使 ImgSeqDataset 成为抽象基类。同时,通过使用 @abstractmethod 装饰器,声明了一个抽象方法 collect_samples,该方法在子类中必须被实现。

之后我们看这个加载数据的函数,在Python中,以单个下划线 _ 开头的命名约定通常表示该成员是一个"内部"成员,即它是类的内部使用或是模块的内部实现的一部分。这并不是强制规定,而是一种编码风格的约定,旨在向其他开发者传达该成员不应该被直接访问或使用,而是被认为是类的内部实现的一部分。

继承这个基类的抽象类必须要实现collect_samples方法,因为对于不同的数据集,格式是不一样的,必然使用的方法也不一样呀,这里注意一下,是返回一个样本字典。


定义dataset的时候,我们必须定义的有len和gettiem,前者这里不说了,这里主要说一下后者。

我们通过collect_samples可以获取到所有的样本字典,在getitem中,先通过加载样本获取到图像数据和目标字典,之后构建类似于 {‘img1’: image1, ‘img2’: image2, ‘img3’: image3} 的键值对。这样的数据结构可以用于在后续的处理中,更方便地引用和操作图像序列中的每一帧。之后在进行图像变换。

这里注意一下,先将图像读取进来,之后是以字典的形式保存。

其实论文中是先在raw数据,即原始数据,然后才在标准数据集上进行训练,因为是无监督嘛,所以自然可以使用全部的帧,但是这里有个问题呀,就是你可以使用全部的数据,自然也是可以包含了测试数据了,这有点不公平呀,所以大多数的文章好像都不是这么干的呀,这也不太现实呀,所以我这里只制作标准的数据集,sintel,kittil12/15,其实就是这两个有说服了,flychairs也算一个,其实每个文章的方法都不一样呀,对比的话我感觉不太公平,但是没办法就是最后的结果吗?

sintel数据dataset

原工程中有raw,我们这里舍弃了这个,后面会使用flychairs进行预训练,当然这个预训练也是无监督的。下面直接给出代码,注释的很清晰了:

class Sintel(ImgSeqDataset):
    def __init__(self, root, n_frames=2, type='clean', split='training',
                 subsplit='trainval', with_flow=True, ap_transform=None,
                 transform=None, target_transform=None, co_transform=None):
        """
        Sintel 数据集类,用于加载 Sintel 数据集的图像序列。

        参数:
        - root (str): 数据集根目录的路径。
        - n_frames (int): 图像序列的帧数,默认为2。
        - type (str): 数据集类型,'clean' 或 'final'。
        - split (str): 数据集分割,'training' 或 'test'。
        - subsplit (str): 数据集子分割,'trainval'、'train' 或 'val'。
        - with_flow (bool): 是否包含光流信息,默认为 True。
        - ap_transform (callable): 姿势估计变换函数,默认为 None。
        - transform (callable): 输入变换函数,默认为 None。
        - target_transform (callable): 目标变换函数,默认为 None。
        - co_transform (callable): 协同变换函数,默认为 None。
        """
        self.dataset_type = type
        self.with_flow = with_flow
        # split为training或者test
        self.split = split
        self.subsplit = subsplit
        self.training_scene = ['alley_1', 'ambush_4', 'ambush_6', 'ambush_7', 'bamboo_2',
                               'bandage_2', 'cave_2', 'market_2', 'market_5', 'shaman_2',
                               'sleeping_2', 'temple_3']  # 非官方的训练-验证拆分

        # 构造数据集根路径
        root = Path(root) / split
        super(Sintel, self).__init__(root, n_frames, input_transform=transform,
                                     target_transform=target_transform,
                                     co_transform=co_transform, ap_transform=ap_transform)
    # 必须要去写的方法
    def collect_samples(self):
        '''收集样本信息的方法,返回一个包含样本信息的列表。'''
        # 图像文件夹
        img_dir = self.root / Path(self.dataset_type)
        # 光流文件夹
        flow_dir = self.root / 'flow'
        # 确保图像目录和光流目录存在
        assert img_dir.isdir() and flow_dir.isdir()

        samples = []
        # 结果是一个按文件路径名排序的光流场文件列表
        for flow_map in sorted((self.root / flow_dir).glob('*/*.flo')):
            # 用于将文件路径分解为场景(scene)和文件名(filename)等信息
            info = flow_map.splitall()
            scene, filename = info[-2:]
            fid = int(filename[-8:-4])

            # 根据分割和子分割规则选择样本
            if self.split == 'training' and self.subsplit != 'trainval':
                if self.subsplit == 'train' and scene not in self.training_scene:
                    continue
                if self.subsplit == 'val' and scene in self.training_scene:
                    continue

            # 构建样本字典,包含图像序列的文件路径列表,这里是一个序列
            s = {'imgs': [img_dir / scene / 'frame_{:04d}.png'.format(fid + i) for i in
                          range(self.n_frames)]}
            try:
                # 检查所有帧图像文件是否存在
                assert all([p.isfile() for p in s['imgs']])

                # 如果需要光流信息,则添加到样本字典中
                if self.with_flow:
                    if self.n_frames == 3:
                        # 对于 img1 img2 img3,只评估 flow_23
                        s['flow'] = flow_dir / scene / 'frame_{:04d}.flo'.format(fid + 1)
                    elif self.n_frames == 2:
                        # 对于 img1 img2,评估 flow_12
                        s['flow'] = flow_dir / scene / 'frame_{:04d}.flo'.format(fid)
                    else:
                        raise NotImplementedError(
                            'n_frames {} with flow or mask'.format(self.n_frames))

                    # 检查光流文件是否存在
                    if self.with_flow:
                        assert s['flow'].isfile()
            except AssertionError:
                print('Incomplete sample for: {}'.format(s['imgs'][0]))
                continue

            # 将当前样本添加到样本列表中
            samples.append(s)

        # 返回所有样本的列表
        return samples

这里面的training_scene要注意一下,如果我们选择trainval是不用管这个参数的,我也试了一下,发现所有的training数据都会用到的,然后下面是我的一个测试的代码吧:

from torch.utils.data import DataLoader

from flow_datasets import  Sintel
# 假设你的 Sintel 数据集位于'/path/to/sintel'目录下
sintel_dataset = Sintel(root='D:/AI_Do/data/MPI-Sintel', n_frames=2, type='clean', split='training', subsplit='trainval', with_flow=True)

# 实例化 DataLoader,用于批量加载数据
data_loader = DataLoader(sintel_dataset, batch_size=1, shuffle=True, num_workers=0)

# 迭代数据集中的批次
for batch in data_loader:
    print(batch)

    break

flychairs数据dataset

kitti2012/2015我还没有下完,就先把这个弄了吧,ARFlow中并没有这个代码,这里我们尝试借鉴RAFT的代码写一下:

class Flychairs(ImgSeqDataset):
    def __init__(self, root, n_frames=2, subsplit='trainval',split = 'training', split_file='FlyingChairs_train_val.txt',with_flow=True, ap_transform=None,
                 transform=None, target_transform=None, co_transform=None):
        """
        Sintel 数据集类,用于加载 Sintel 数据集的图像序列。

        参数:
        - root (str): 数据集根目录的路径:D:\AI_Do\data\FlyingChairs_release
        - n_frames (int): 图像序列的帧数,默认为2。
        - split(str): 分割形式,training 或 validation。
        - split_file (str): 分割文件。
        - subsplit (str): 数据集子分割,'trainval'、'train' 或 'val'。
        - with_flow (bool): 是否包含光流信息,默认为 True。
        - ap_transform (callable): 姿势估计变换函数,默认为 None。
        - transform (callable): 输入变换函数,默认为 None。
        - target_transform (callable): 目标变换函数,默认为 None。
        - co_transform (callable): 协同变换函数,默认为 None。
        """
        self.with_flow = with_flow
        self.subsplit = subsplit
        self.split = split
        self.root = root
        # 构造分割文件路径
        self.separate_file = os.path.join(root, split_file)
        # 构建数据文件夹路径
        self.data_dir = os.path.join(root, 'data')
        # 确保数据文件夹存在
        assert os.path.isdir(self.data_dir)

        super(Flychairs, self).__init__(root, n_frames, input_transform=transform,
                                     target_transform=target_transform,
                                     co_transform=co_transform, ap_transform=ap_transform)

    def collect_samples(self):
        samples = []
        # 获取所有的ppm(图像)和flo(光流图)文件
        images = sorted(glob(osp.join(self.data_dir, '*.ppm')))
        flows = sorted(glob(osp.join(self.data_dir, '*.flo')))
        assert (len(images) // 2 == len(flows))
        # 打开分割文件,全是数字1或者2,1代表训练,2代表验证
        split_list = np.loadtxt(self.separate_file, dtype=np.int32)
        for i in range(len(flows)):
            xid = split_list[i]
            # 根据split的不同构建不同的数据集
            if (self.split == 'training' and xid == 1) or (self.split == 'validation' and xid == 2):
                # 构建样本字典,包含图像序列的文件路径列表,这里是一个序列
                s = {'imgs': [images[2 * i], images[2 * i + 1]]}
                try:
                    # 检查所有帧图像文件是否存在
                    assert all([os.path.isfile(p) for p in s['imgs']])

                    # 如果需要光流信息,则添加到样本字典中
                    if self.with_flow:
                        if self.n_frames == 3:
                            # 对于 img1 img2 img3,只评估 flow_23
                            s['flow'] = flows[i+1]
                        elif self.n_frames == 2:
                            # 对于 img1 img2,评估 flow_12
                            s['flow'] = flows[i]
                        else:
                            raise NotImplementedError(
                                'n_frames {} with flow or mask'.format(self.n_frames))
                        # 检查光流文件是否存在
                        if self.with_flow:
                            assert os.path.isfile(s['flow'])

                except AssertionError:
                    print('Incomplete sample for: {}'.format(s['imgs'][0]))
                    continue
                # 将当前样本添加到样本列表中
                samples.append(s)
        # 返回所有样本的列表
        return samples

注释写的很清晰了,然后说一个细节,就是我看了验证集的部分,对于flychairs和sintel是不一样的,前者是因为没有test集,所以就将一部分不放在训练里,而后者是有专门的test集的,所以就是training没有分一下,直接全部进行训练了,并且将final和clean都放在一起训练了,放在一起训练的数据是多的,我们也先这么做吧。

测试代码与上面的差不多,这里就不放上来了。

kitti数据dataset

这个数据集下的很慢,但是现在也是下完了,现在开始吧,这里先说一下,在ARFlow中是将kitti所有的原始数据都下载下来了,我们这里并没有,而是使用的是12/15的多视角的数据集,就是两个10多G的数据集,在ARFlow中作者说直接使用这个训练是差不多的,我们这里先来查看这个数据集。下载完毕后呈现的就是现在的情况
在这里插入图片描述
其中的multiview的话,全是图像,没有标注数据,我们会使用这个进行无监督训练的,如在第二个文件夹下的image_2的文件夹下有4200张图片。而第一个和第三个是有标注文件的,但数据会少很多,如在第一个文件夹下的image_2的文件夹仅有400张图像。那么我们现在开始学习代码吧!并且我们这里提一下image_0与image_1是灰度的图像,所以我们这里主要使用image_2和image_3。

在ARFlow的工程中定义了三个dataset的类,分别是raw数据,无监督训练,只验证三个dataset,我们这里因为不使用raw数据,所以只考虑后面两个。

在看代码之前,我们先把标签数据复制一下过来,但是我们训练的时候跳过9~12帧,为了后面使用这些数据进行验证的时候用的,因为测试集并没有公开标注,这也是现在主流的做法。
在这里插入图片描述

class KITTIFlowMV(ImgSeqDataset):
    """
    这个数据集仅用于无监督训练
    """

    def __init__(self, root, n_frames=2,
                 transform=None, co_transform=None, ap_transform=None, ):
        super(KITTIFlowMV, self).__init__(root, n_frames,
                                          input_transform=transform,
                                          co_transform=co_transform,
                                          ap_transform=ap_transform)

    def collect_samples(self):
        # 设置光流图路径
        flow_occ_dir = 'flow_' + 'occ'
        assert (self.root / flow_occ_dir).isdir()

        # 设置左右图像路径
        img_l_dir, img_r_dir = 'image_2', 'image_3'
        assert (self.root / img_l_dir).isdir() and (self.root / img_r_dir).isdir()

        samples = []
        # 遍历光流图目录下所有以 '.png' 结尾的文件
        for flow_map in sorted((self.root / flow_occ_dir).glob('*.png')):
            flow_map = flow_map.basename()
            root_filename = flow_map[:-7]

            for img_dir in [img_l_dir, img_r_dir]:
                # 获取与光流图文件名相关的图像文件列表
                img_list = (self.root / img_dir).files('*{}*.png'.format(root_filename))
                img_list.sort()

                for st in range(0, len(img_list) - self.n_frames + 1):
                    # 获取图像序列
                    seq = img_list[st:st + self.n_frames]
                    sample = {}
                    sample['imgs'] = []
                    for i, file in enumerate(seq):
                        # 解析文件名中的帧号
                        frame_id = int(file[-6:-4])
                        # 如果帧号在 9 和 12 之间(包括 9 和 12),终止循环
                        if 12 >= frame_id >= 9:
                            break
                        sample['imgs'].append(self.root.relpathto(file))
                    # 如果采样到了足够的图像帧,添加到样本列表中
                    if len(sample['imgs']) == self.n_frames:
                        samples.append(sample)
        return samples

后来我去看有的论文其做法只是在kitti2015的多视角进行训练,其中把9~12帧去掉,将这些有标注的数据做为验证集,这里我有个疑惑就是,对于2015行了,直接使用多视角训练后进行验证,但是对于2012呢?难道不再使用2012的多视角进行训练,然后在2012的验证集上验证吗?这现在就是我的困惑?

我悟了,感谢灵感,刚才我有个念头,就是让我去看一下好了,结果我对比了一下,发现是一样的,哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈,那就证明了我的猜想,单用kitti2015进行训练,然后再使用这个模型直接对2012和2015的论文进行验证。
在这里插入图片描述
ARFlow的作者,你为什么这么贴心呢?这么有爱呢?生怕别人不理解,你就好像一座黑暗中的灯塔照亮抹黑前行的人们,我如果发出了论文,我一定会感谢你的,我也会使用的你的论文进行改进,谢谢你们!

刚才我又重新看了ARFlow的文档,发现作者也是贴心的附上了sintel的raw数据集,我真的是,其实对于无监督学习来说,应该发挥自己的优势,就是可以使用更多的没有标注的数据,这就是优势嘛!

这里最后给验证数据dataset进行注释一下:

class KITTIFlow(ImgSeqDataset):
    """
    该数据集仅用于验证,因此关于目标的所有文件都以文件路径形式存储,
    并且目标的转换没有进行任何处理。
    """
    def __init__(self, root, n_frames=2, transform=None):
        super(KITTIFlow, self).__init__(root, n_frames, input_transform=transform)

    def __getitem__(self, idx):
        # 获取样本信息
        s = self.samples[idx]

        # 对于2帧,img 1和2用于构成图像序列;对于3帧,img 0、1和2用于构成图像序列。
        st = 1 if self.n_frames == 2 else 0
        ed = st + self.n_frames
        imgs = [s['img{}'.format(i)] for i in range(st, ed)]

        # 读取图像数据并获取原始大小
        inputs = [imageio.imread(self.root / p).astype(np.float32) for p in imgs]
        raw_size = inputs[0].shape[:2]

        # 构建数据字典
        data = {
            'flow_occ': self.root / s['flow_occ'],
            'flow_noc': self.root / s['flow_noc'],
        }

        # 为测试集添加额外信息
        data.update({
            'im_shape': raw_size,
            'img1_path': self.root / s['img1'],
        })

        # 如果存在输入变换,对输入图像进行处理
        if self.input_transform is not None:
            inputs = [self.input_transform(i) for i in inputs]
        data.update({'img{}'.format(i + 1): inputs[i] for i in range(self.n_frames)})

        return data

    def collect_samples(self):
        '''在训练文件夹中搜索包含 'flow_noc' 或 'flow_occ'
           以及 'colored_0'(KITTI 2012)或 'image_2'(KITTI 2015)的文件夹'''

        # 设置光流图路径
        flow_occ_dir = 'flow_' + 'occ'
        flow_noc_dir = 'flow_' + 'noc'
        assert (self.root / flow_occ_dir).isdir()

        # 设置图像路径
        img_dir = 'image_2'
        assert (self.root / img_dir).isdir()

        samples = []
        # 遍历光流图目录下所有以 '.png' 结尾的文件
        for flow_map in sorted((self.root / flow_occ_dir).glob('*.png')):
            flow_map = flow_map.basename()
            root_filename = flow_map[:-7]

            # 构建光流图文件路径
            flow_occ_map = flow_occ_dir + '/' + flow_map
            flow_noc_map = flow_noc_dir + '/' + flow_map
            s = {'flow_occ': flow_occ_map, 'flow_noc': flow_noc_map}

            # 构建图像文件路径
            img1 = img_dir + '/' + root_filename + '_10.png'
            img2 = img_dir + '/' + root_filename + '_11.png'
            assert (self.root / img1).isfile() and (self.root / img2).isfile()
            s.update({'img1': img1, 'img2': img2})

            # 对于3帧,添加额外的图像文件路径
            if self.n_frames == 3:
                img0 = img_dir + '/' + root_filename + '_09.png'
                assert (self.root / img0).isfile()
                s.update({'img0': img0})

            samples.append(s)
        return samples

这里面也有个细节,就是在使用验证的时候,只是使用了image2这个视角的图像,并没有使用image3。

sintel_raw

因为我们获取到了这个数据,现在还是弄一下这个代码吧:

class SintelRaw(ImgSeqDataset):
    def __init__(self, root, n_frames=2, transform=None, co_transform=None):
        super(SintelRaw, self).__init__(root, n_frames, input_transform=transform,
                                        co_transform=co_transform)

    def collect_samples(self):
        scene_list = self.root.dirs()
        samples = []
        for scene in scene_list:
            img_list = scene.files('*.png')
            img_list.sort()

            for st in range(0, len(img_list) - self.n_frames + 1):
                seq = img_list[st:st + self.n_frames]
                sample = {'imgs': [self.root.relpathto(file) for file in seq]}
                samples.append(sample)
        return samples

在这里插入图片描述

接口文件

我们接下来看一下接口的文件,就是将我们前面的代码集成一下,先看下要引入的包

import copy
from torchvision import transforms

from torch.utils.data import ConcatDataset
from transforms.co_transforms import get_co_transforms
from transforms.ar_transforms.ap_transforms import get_ap_transforms
from transforms import sep_transforms
from datasets.flow_datasets import SintelRaw, Sintel
from datasets.flow_datasets import KITTIRawFile, KITTIFlow, KITTIFlowMV

其实当自己低频的时候,平日里学到的绝大多数知识都用不上,唯有让自己的大脑安静下来,才会获得力量,所以静定慧的静真的就是万道之门。

对应的包

上面的代码中torch.utils.data.ConcatDataset 是 PyTorch 中用于将多个数据集合并成一个的工具类。它接受一个包含多个数据集的列表,并将它们串联在一起,以便在训练和验证中更方便地处理多个数据来源。

并且有很多变换的操作,我们先看一下sep_transforms:

import numpy as np
import torch
# from scipy.misc import imresize
from skimage.transform import resize as imresize

class ArrayToTensor(object):
    """将numpy.ndarray(H x W x C)转换为torch.FloatTensor(C x H x W)。"""

    def __call__(self, array):
        assert (isinstance(array, np.ndarray))
        # 将数组的轴顺序从HWC(Height x Width x Channels)转换为CHW(Channels x Height x Width)
        array = np.transpose(array, (2, 0, 1))
        # 转换为torch张量
        tensor = torch.from_numpy(array)
        # 将数据类型转换为float
        return tensor.float()

class Zoom(object):
    def __init__(self, new_h, new_w):
        self.new_h = new_h
        self.new_w = new_w

    def __call__(self, image):
        h, w, _ = image.shape
        # 如果图像已经是目标尺寸,则直接返回
        if h == self.new_h and w == self.new_w:
            return image
        # 使用imresize函数调整图像大小
        image = imresize(image, (self.new_h, self.new_w))
        return image

具体来说,对于一个类的实例 obj,如果该类定义了 call 方法,那么可以使用 obj() 的形式调用这个实例,就像调用一个函数一样

get_co_transforms 函数:用于获取数据增强的组合操作,根据输入的数据增强参数创建对应的操作列表。
Compose 类:将多个数据增强操作组合在一起的类,实例可以像函数一样被调用,对输入图像和目标图像进行组合操作。
RandomCrop 类:在随机位置对给定图像进行裁剪,使其具有给定尺寸的区域。
RandomSwap 类:以一定概率随机交换输入图像列表的顺序。
RandomHorizontalFlip 类:以0.5的概率随机水平翻转给定图像。

import numbers
import random
import numpy as np
# from scipy.misc import imresize
from skimage.transform import resize as imresize
import scipy.ndimage as ndimage

def get_co_transforms(aug_args):
    """
    获取数据增强的组合操作。
    :param aug_args: 包含数据增强参数的对象。
    :return: 数据增强的组合操作。
    """
    transforms = []
    if aug_args.crop:
        transforms.append(RandomCrop(aug_args.para_crop))
    if aug_args.hflip:
        transforms.append(RandomHorizontalFlip())
    if aug_args.swap:
        transforms.append(RandomSwap())
    return Compose(transforms)

class Compose(object):
    def __init__(self, co_transforms):
        """
        将多个数据增强操作组合在一起的类。
        :param co_transforms: 数据增强操作的列表。
        """
        self.co_transforms = co_transforms

    def __call__(self, input, target):
        """
        对输入图像和目标进行组合的调用方法。
        :param input: 输入图像。
        :param target: 目标图像。
        :return: 组合后的输入和目标。
        """
        for t in self.co_transforms:
            input, target = t(input, target)
        return input, target

class RandomCrop(object):
    """在随机位置对给定图像进行裁剪,使其具有给定尺寸的区域。
    尺寸可以是一个元组(目标高度,目标宽度)或一个整数,此时目标将是一个正方形(size, size)。
    """
    def __init__(self, size):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size

    def __call__(self, inputs, target):
        h, w, _ = inputs[0].shape
        th, tw = self.size
        if w == tw and h == th:
            return inputs, target

        x1 = random.randint(0, w - tw)
        y1 = random.randint(0, h - th)
        inputs = [img[y1: y1 + th, x1: x1 + tw] for img in inputs]
        if 'mask' in target:
            target['mask'] = target['mask'][y1: y1 + th, x1: x1 + tw]
        if 'flow' in target:
            target['flow'] = target['flow'][y1: y1 + th, x1: x1 + tw]
        return inputs, target

class RandomSwap(object):
    '''RandomSwap类:以一定概率随机交换输入图像列表的顺序。'''
    def __call__(self, inputs, target):
        n = len(inputs)
        if random.random() < 0.5:
            inputs = inputs[::-1]
            if 'mask' in target:
                target['mask'] = target['mask'][::-1]
            if 'flow' in target:
                raise NotImplementedError("swap cannot apply to flow")
        return inputs, target

class RandomHorizontalFlip(object):
    """以0.5的概率随机水平翻转给定图像。"""
    def __call__(self, inputs, target):
        if random.random() < 0.5:
            inputs = [np.copy(np.fliplr(im)) for im in inputs]
            if 'mask' in target:
                target['mask'] = [np.copy(np.fliplr(mask)) for mask in target['mask']]
            if 'flow' in target:
                for i, flo in enumerate(target['flow']):
                    flo = np.copy(np.fliplr(flo))
                    flo[:, :, 0] *= -1
                    target['flow'][i] = flo
        return inputs, target

这些代码我们可以直接复用。最后我们看下最后的文件夹ar变换,这个好像得去看下论文呀。我们这里先把数据dataset弄完

import copy
from torchvision import transforms

from torch.utils.data import ConcatDataset
from transforms.co_transforms import get_co_transforms
from transforms.ar_transforms.ap_transforms import get_ap_transforms
from transforms import sep_transforms
from datasets.flow_datasets import SintelRaw, Sintel
from datasets.flow_datasets import KITTIFlow, KITTIFlowMV
from datasets.flow_datasets import Flychairs

def get_dataset(all_cfg):
    cfg = all_cfg.data
    # 输入数据变换
    input_transform = transforms.Compose([
        sep_transforms.ArrayToTensor(),
        transforms.Normalize(mean=[0, 0, 0], std=[255, 255, 255]),
    ])
    # 获取协同变换操作
    co_transform = get_co_transforms(aug_args=all_cfg.data_aug)

    # Sintel_Flow训练
    if cfg.type == 'Sintel_Flow':
        ap_transform = get_ap_transforms(cfg.at_cfg) if cfg.run_at else None

        train_set_1 = Sintel(cfg.root_sintel, n_frames=cfg.train_n_frames, type='clean',
                             split='training', subsplit=cfg.train_subsplit,
                             with_flow=False,
                             ap_transform=ap_transform,
                             transform=input_transform,
                             co_transform=co_transform
                             )
        train_set_2 = Sintel(cfg.root_sintel, n_frames=cfg.train_n_frames, type='final',
                             split='training', subsplit=cfg.train_subsplit,
                             with_flow=False,
                             ap_transform=ap_transform,
                             transform=input_transform,
                             co_transform=co_transform
                             )
        # 组合训练数据dataset
        train_set = ConcatDataset([train_set_1, train_set_2])

        # 深度拷贝
        valid_input_transform = copy.deepcopy(input_transform)
        # 将新创建的缩放变换插入到验证数据变换的第一个位置。
        # 这是因为在验证阶段,通常希望对输入图像进行一些标准的预处理,例如调整尺寸,以适应模型的输入要求。
        valid_input_transform.transforms.insert(0, sep_transforms.Zoom(*cfg.test_shape))
        valid_set_1 = Sintel(cfg.root_sintel, n_frames=cfg.val_n_frames, type='clean',
                             split='training', subsplit=cfg.val_subsplit,
                             transform=valid_input_transform,
                             target_transform={'flow': sep_transforms.ArrayToTensor()}
                             )
        valid_set_2 = Sintel(cfg.root_sintel, n_frames=cfg.val_n_frames, type='final',
                             split='training', subsplit=cfg.val_subsplit,
                             transform=valid_input_transform,
                             target_transform={'flow': sep_transforms.ArrayToTensor()}
                             )
        # 组合验证数据dataset
        valid_set = ConcatDataset([valid_set_1, valid_set_2])

    # Sintel_Raw训练
    elif cfg.type == 'Sintel_Raw':
        train_set = SintelRaw(cfg.root_sintel_raw, n_frames=cfg.train_n_frames,
                              transform=input_transform, co_transform=co_transform)
        valid_input_transform = copy.deepcopy(input_transform)
        valid_input_transform.transforms.insert(0, sep_transforms.Zoom(*cfg.test_shape))

        valid_set_1 = Sintel(cfg.root_sintel, n_frames=cfg.val_n_frames, type='clean',
                             split='training', subsplit=cfg.val_subsplit,
                             transform=valid_input_transform,
                             target_transform={'flow': sep_transforms.ArrayToTensor()}
                             )
        valid_set_2 = Sintel(cfg.root_sintel, n_frames=cfg.val_n_frames, type='final',
                             split='training', subsplit=cfg.val_subsplit,
                             transform=valid_input_transform,
                             target_transform={'flow': sep_transforms.ArrayToTensor()}
                             )
        valid_set = ConcatDataset([valid_set_1, valid_set_2])

    # Flycharis上进行预训练
    elif cfg.type == 'Flycharis':
        train_input_transform = copy.deepcopy(input_transform)
        train_input_transform.transforms.insert(0, sep_transforms.Zoom(*cfg.train_shape))

        ap_transform = get_ap_transforms(cfg.at_cfg) if cfg.run_at else None
        train_set = Flychairs(
            cfg.root,
            cfg.train_n_frames,
            split='training',
            transform=train_input_transform,
            ap_transform=ap_transform,
            co_transform=co_transform  # no target here
        )

        valid_input_transform = copy.deepcopy(input_transform)
        valid_input_transform.transforms.insert(0, sep_transforms.Zoom(*cfg.test_shape))
        valid_set = Flychairs(
            cfg.root,
            cfg.train_n_frames,
            split='validation',
            transform=train_input_transform,
            ap_transform=ap_transform,
            co_transform=co_transform  # no target here
        )

    # 在多视角KITTI上进行训练
    elif cfg.type == 'KITTI_MV':
        train_input_transform = copy.deepcopy(input_transform)
        train_input_transform.transforms.insert(0, sep_transforms.Zoom(*cfg.train_shape))

        root_flow = cfg.root_kitti15 if cfg.train_15 else cfg.root_kitti12

        ap_transform = get_ap_transforms(cfg.at_cfg) if cfg.run_at else None
        train_set = KITTIFlowMV(
            root_flow,
            cfg.train_n_frames,
            transform=train_input_transform,
            ap_transform=ap_transform,
            co_transform=co_transform  # no target here
        )
        # 深度拷贝
        valid_input_transform = copy.deepcopy(input_transform)
        # 将新创建的缩放变换插入到验证数据变换的第一个位置。
        # 这是因为在验证阶段,通常希望对输入图像进行一些标准的预处理,例如调整尺寸,以适应模型的输入要求。
        valid_input_transform.transforms.insert(0, sep_transforms.Zoom(*cfg.test_shape))
        # 验证数据dataset构建
        valid_set_1 = KITTIFlow(cfg.root_kitti15, n_frames=cfg.val_n_frames,
                                transform=valid_input_transform,
                                )
        valid_set_2 = KITTIFlow(cfg.root_kitti12, n_frames=cfg.val_n_frames,
                                transform=valid_input_transform,
                                )
        valid_set = ConcatDataset([valid_set_1, valid_set_2])
    else:
        raise NotImplementedError(cfg.type)
    return train_set, valid_set

浅拷贝创建一个新对象,然后将原始对象中的元素复制到新对象中。但是,如果原始对象包含嵌套对象(例如列表中包含了另一个列表),则浅拷贝只复制了嵌套对象的引用,而不是创建嵌套对象的新实例。

深度拷贝创建一个新对象,并递归地复制原始对象及其所有嵌套对象的副本。这意味着,即使原始对象包含嵌套对象,深度拷贝也会创建嵌套对象的新实例,而不仅仅是复制引用。

  • 13
    点赞
  • 25
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值