前言
宝宝心里苦呀,调了几天的官方的tensorflow代码怎么都弄不通,后来一想,本科的毕业设计好像没有精度要求吧,我不如自己去搭建一个pipline,这样以后我也可以在这个基础上进行改进嘛。并且最理想的就是可以去copy下大佬写的代码嘛,但是很难受,很多都跑不通,终于我找到一个可以跑通的工程了:ARFlow,很感谢这个作者,让我在黑暗中看到了光,我的世界没有别人,你是未来的我给我的帮助嘛,感谢引领!
Here we go!
数据dataset基类
我们在构建不同数据dataset的时候,要先去定义一个基类,之后面对不同的数据集定义不同的dataset类,我们这里借助ARFlow的代码进行,下面是注释后的代码。
import imageio
import numpy as np
import random
from path import Path
from abc import abstractmethod, ABCMeta
from torch.utils.data import Dataset
from utils.flow_utils import load_flow
class ImgSeqDataset(Dataset, metaclass=ABCMeta):
def __init__(self, root, n_frames, input_transform=None, co_transform=None,
target_transform=None, ap_transform=None):
# 数据集路径
self.root = Path(root)
# 图像序列帧数,现在很多论文开始使用3帧了
self.n_frames = n_frames
# 对输入帧进行变换
self.input_transform = input_transform
# 协同变换
self.co_transform = co_transform
# 姿态变换
self.ap_transform = ap_transform
# 对目标进行变换,这里指光流和mask
self.target_transform = target_transform
# 通过collect_samples方法得到图像
self.samples = self.collect_samples()
@abstractmethod
def collect_samples(self):
# 抽象方法,子类需要实现用于收集样本信息的方法
pass
def _load_sample(self, s):
# 加载单个样本的方法
# 从样本字典中获取图像序列路径列表
images = s['imgs']
# 读取图像序列中的每一帧图像,self.root / p 表示构建完整的数据路径
# 将每一帧图像加载为 NumPy 数组,并将数据类型转换为 float32
images = [imageio.imread(self.root / p).astype(np.float32) for p in images]
# 初始化目标字典
target = {}
# 如果样本中包含光流信息,则加载光流数据
if 'flow' in s:
target['flow'] = load_flow(self.root / s['flow'])
# 如果样本中包含掩码信息,则加载并处理掩码数据
if 'mask' in s:
mask = imageio.imread(self.root / s['mask']).astype(np.float32) / 255.
# 如果掩码数据是三维的,只选择第一个通道
if len(mask.shape) == 3:
mask = mask[:, :, 0]
# 将掩码数据的通道维度扩展,变成 (H, W, 1) 的形状
target['mask'] = np.expand_dims(mask, -1)
# 返回加载的图像序列列表和目标字典
return images, target
def __len__(self):
# 返回数据集中样本的数量
return len(self.samples)
def __getitem__(self, idx):
# 获取指定索引处的样本
images, target = self._load_sample(self.samples[idx])
if self.co_transform is not None:
# 如果定义了协同变换,则应用协同变换
images, _ = self.co_transform(images, {})
if self.input_transform is not None:
# 如果定义了输入变换,则应用输入变换
images = [self.input_transform(i) for i in images]
data = {'img{}'.format(i + 1): p for i, p in enumerate(images)}
if self.ap_transform is not None:
# 如果定义了姿势估计变换,则应用姿势估计变换
imgs_ph = self.ap_transform(
[data['img{}'.format(i + 1)].clone() for i in range(self.n_frames)])
for i in range(self.n_frames):
data['img{}_ph'.format(i + 1)] = imgs_ph[i]
if self.target_transform is not None:
# 如果定义了目标变换,则应用目标变换
for key in self.target_transform.keys():
target[key] = self.target_transform[key](target[key])
data['target'] = target
return data
这里我们看到定义一个图像对的dataset,然后大多数都在注释中,这里说点细节或者需要知道的一些知识:
from abc import abstractmethod, ABCMeta
先说下这个包的作用吧,在给定的代码中,from abc import abstractmethod, ABCMeta 引入了 Python 标准库中的 abc 模块,其中包括两个重要的元素:abstractmethod 和 ABCMeta。
abstractmethod 是一个装饰器(decorator),用于声明一个方法是抽象方法。抽象方法是在基类中声明但没有提供具体实现的方法,留待子类去实现。在抽象类中,至少包含一个抽象方法。子类必须实现所有抽象方法,否则无法被实例化。
ABCMeta 是一个元类,用于创建抽象基类(Abstract Base Class,ABC)。元类是一种类的类,它定义了类的行为,如何创建类的实例以及类与其实例的关系。ABCMeta 是 Python 中用于创建抽象基类的特殊元类。通过使用 ABCMeta 元类,类可以被定义为抽象基类,并且可以包含抽象方法。
在给定的代码中,这两个元素被用于定义一个抽象基类 ImgSeqDataset。通过在类定义时使用 metaclass=ABCMeta,指定了该类的元类为 ABCMeta,从而使 ImgSeqDataset 成为抽象基类。同时,通过使用 @abstractmethod 装饰器,声明了一个抽象方法 collect_samples,该方法在子类中必须被实现。
之后我们看这个加载数据的函数,在Python中,以单个下划线 _ 开头的命名约定通常表示该成员是一个"内部"成员,即它是类的内部使用或是模块的内部实现的一部分。这并不是强制规定,而是一种编码风格的约定,旨在向其他开发者传达该成员不应该被直接访问或使用,而是被认为是类的内部实现的一部分。
继承这个基类的抽象类必须要实现collect_samples方法,因为对于不同的数据集,格式是不一样的,必然使用的方法也不一样呀,这里注意一下,是返回一个样本字典。
定义dataset的时候,我们必须定义的有len和gettiem,前者这里不说了,这里主要说一下后者。
我们通过collect_samples可以获取到所有的样本字典,在getitem中,先通过加载样本获取到图像数据和目标字典,之后构建类似于 {‘img1’: image1, ‘img2’: image2, ‘img3’: image3} 的键值对。这样的数据结构可以用于在后续的处理中,更方便地引用和操作图像序列中的每一帧。之后在进行图像变换。
这里注意一下,先将图像读取进来,之后是以字典的形式保存。
其实论文中是先在raw数据,即原始数据,然后才在标准数据集上进行训练,因为是无监督嘛,所以自然可以使用全部的帧,但是这里有个问题呀,就是你可以使用全部的数据,自然也是可以包含了测试数据了,这有点不公平呀,所以大多数的文章好像都不是这么干的呀,这也不太现实呀,所以我这里只制作标准的数据集,sintel,kittil12/15,其实就是这两个有说服了,flychairs也算一个,其实每个文章的方法都不一样呀,对比的话我感觉不太公平,但是没办法就是最后的结果吗?
sintel数据dataset
原工程中有raw,我们这里舍弃了这个,后面会使用flychairs进行预训练,当然这个预训练也是无监督的。下面直接给出代码,注释的很清晰了:
class Sintel(ImgSeqDataset):
def __init__(self, root, n_frames=2, type='clean', split='training',
subsplit='trainval', with_flow=True, ap_transform=None,
transform=None, target_transform=None, co_transform=None):
"""
Sintel 数据集类,用于加载 Sintel 数据集的图像序列。
参数:
- root (str): 数据集根目录的路径。
- n_frames (int): 图像序列的帧数,默认为2。
- type (str): 数据集类型,'clean' 或 'final'。
- split (str): 数据集分割,'training' 或 'test'。
- subsplit (str): 数据集子分割,'trainval'、'train' 或 'val'。
- with_flow (bool): 是否包含光流信息,默认为 True。
- ap_transform (callable): 姿势估计变换函数,默认为 None。
- transform (callable): 输入变换函数,默认为 None。
- target_transform (callable): 目标变换函数,默认为 None。
- co_transform (callable): 协同变换函数,默认为 None。
"""
self.dataset_type = type
self.with_flow = with_flow
# split为training或者test
self.split = split
self.subsplit = subsplit
self.training_scene = ['alley_1', 'ambush_4', 'ambush_6', 'ambush_7', 'bamboo_2',
'bandage_2', 'cave_2', 'market_2', 'market_5', 'shaman_2',
'sleeping_2', 'temple_3'] # 非官方的训练-验证拆分
# 构造数据集根路径
root = Path(root) / split
super(Sintel, self).__init__(root, n_frames, input_transform=transform,
target_transform=target_transform,
co_transform=co_transform, ap_transform=ap_transform)
# 必须要去写的方法
def collect_samples(self):
'''收集样本信息的方法,返回一个包含样本信息的列表。'''
# 图像文件夹
img_dir = self.root / Path(self.dataset_type)
# 光流文件夹
flow_dir = self.root / 'flow'
# 确保图像目录和光流目录存在
assert img_dir.isdir() and flow_dir.isdir()
samples = []
# 结果是一个按文件路径名排序的光流场文件列表
for flow_map in sorted((self.root / flow_dir).glob('*/*.flo')):
# 用于将文件路径分解为场景(scene)和文件名(filename)等信息
info = flow_map.splitall()
scene, filename = info[-2:]
fid = int(filename[-8:-4])
# 根据分割和子分割规则选择样本
if self.split == 'training' and self.subsplit != 'trainval':
if self.subsplit == 'train' and scene not in self.training_scene:
continue
if self.subsplit == 'val' and scene in self.training_scene:
continue
# 构建样本字典,包含图像序列的文件路径列表,这里是一个序列
s = {'imgs': [img_dir / scene / 'frame_{:04d}.png'.format(fid + i) for i in
range(self.n_frames)]}
try:
# 检查所有帧图像文件是否存在
assert all([p.isfile() for p in s['imgs']])
# 如果需要光流信息,则添加到样本字典中
if self.with_flow:
if self.n_frames == 3:
# 对于 img1 img2 img3,只评估 flow_23
s['flow'] = flow_dir / scene / 'frame_{:04d}.flo'.format(fid + 1)
elif self.n_frames == 2:
# 对于 img1 img2,评估 flow_12
s['flow'] = flow_dir / scene / 'frame_{:04d}.flo'.format(fid)
else:
raise NotImplementedError(
'n_frames {} with flow or mask'.format(self.n_frames))
# 检查光流文件是否存在
if self.with_flow:
assert s['flow'].isfile()
except AssertionError:
print('Incomplete sample for: {}'.format(s['imgs'][0]))
continue
# 将当前样本添加到样本列表中
samples.append(s)
# 返回所有样本的列表
return samples
这里面的training_scene要注意一下,如果我们选择trainval是不用管这个参数的,我也试了一下,发现所有的training数据都会用到的,然后下面是我的一个测试的代码吧:
from torch.utils.data import DataLoader
from flow_datasets import Sintel
# 假设你的 Sintel 数据集位于'/path/to/sintel'目录下
sintel_dataset = Sintel(root='D:/AI_Do/data/MPI-Sintel', n_frames=2, type='clean', split='training', subsplit='trainval', with_flow=True)
# 实例化 DataLoader,用于批量加载数据
data_loader = DataLoader(sintel_dataset, batch_size=1, shuffle=True, num_workers=0)
# 迭代数据集中的批次
for batch in data_loader:
print(batch)
break
flychairs数据dataset
kitti2012/2015我还没有下完,就先把这个弄了吧,ARFlow中并没有这个代码,这里我们尝试借鉴RAFT的代码写一下:
class Flychairs(ImgSeqDataset):
def __init__(self, root, n_frames=2, subsplit='trainval',split = 'training', split_file='FlyingChairs_train_val.txt',with_flow=True, ap_transform=None,
transform=None, target_transform=None, co_transform=None):
"""
Sintel 数据集类,用于加载 Sintel 数据集的图像序列。
参数:
- root (str): 数据集根目录的路径:D:\AI_Do\data\FlyingChairs_release
- n_frames (int): 图像序列的帧数,默认为2。
- split(str): 分割形式,training 或 validation。
- split_file (str): 分割文件。
- subsplit (str): 数据集子分割,'trainval'、'train' 或 'val'。
- with_flow (bool): 是否包含光流信息,默认为 True。
- ap_transform (callable): 姿势估计变换函数,默认为 None。
- transform (callable): 输入变换函数,默认为 None。
- target_transform (callable): 目标变换函数,默认为 None。
- co_transform (callable): 协同变换函数,默认为 None。
"""
self.with_flow = with_flow
self.subsplit = subsplit
self.split = split
self.root = root
# 构造分割文件路径
self.separate_file = os.path.join(root, split_file)
# 构建数据文件夹路径
self.data_dir = os.path.join(root, 'data')
# 确保数据文件夹存在
assert os.path.isdir(self.data_dir)
super(Flychairs, self).__init__(root, n_frames, input_transform=transform,
target_transform=target_transform,
co_transform=co_transform, ap_transform=ap_transform)
def collect_samples(self):
samples = []
# 获取所有的ppm(图像)和flo(光流图)文件
images = sorted(glob(osp.join(self.data_dir, '*.ppm')))
flows = sorted(glob(osp.join(self.data_dir, '*.flo')))
assert (len(images) // 2 == len(flows))
# 打开分割文件,全是数字1或者2,1代表训练,2代表验证
split_list = np.loadtxt(self.separate_file, dtype=np.int32)
for i in range(len(flows)):
xid = split_list[i]
# 根据split的不同构建不同的数据集
if (self.split == 'training' and xid == 1) or (self.split == 'validation' and xid == 2):
# 构建样本字典,包含图像序列的文件路径列表,这里是一个序列
s = {'imgs': [images[2 * i], images[2 * i + 1]]}
try:
# 检查所有帧图像文件是否存在
assert all([os.path.isfile(p) for p in s['imgs']])
# 如果需要光流信息,则添加到样本字典中
if self.with_flow:
if self.n_frames == 3:
# 对于 img1 img2 img3,只评估 flow_23
s['flow'] = flows[i+1]
elif self.n_frames == 2:
# 对于 img1 img2,评估 flow_12
s['flow'] = flows[i]
else:
raise NotImplementedError(
'n_frames {} with flow or mask'.format(self.n_frames))
# 检查光流文件是否存在
if self.with_flow:
assert os.path.isfile(s['flow'])
except AssertionError:
print('Incomplete sample for: {}'.format(s['imgs'][0]))
continue
# 将当前样本添加到样本列表中
samples.append(s)
# 返回所有样本的列表
return samples
注释写的很清晰了,然后说一个细节,就是我看了验证集的部分,对于flychairs和sintel是不一样的,前者是因为没有test集,所以就将一部分不放在训练里,而后者是有专门的test集的,所以就是training没有分一下,直接全部进行训练了,并且将final和clean都放在一起训练了,放在一起训练的数据是多的,我们也先这么做吧。
测试代码与上面的差不多,这里就不放上来了。
kitti数据dataset
这个数据集下的很慢,但是现在也是下完了,现在开始吧,这里先说一下,在ARFlow中是将kitti所有的原始数据都下载下来了,我们这里并没有,而是使用的是12/15的多视角的数据集,就是两个10多G的数据集,在ARFlow中作者说直接使用这个训练是差不多的,我们这里先来查看这个数据集。下载完毕后呈现的就是现在的情况
其中的multiview的话,全是图像,没有标注数据,我们会使用这个进行无监督训练的,如在第二个文件夹下的image_2的文件夹下有4200张图片。而第一个和第三个是有标注文件的,但数据会少很多,如在第一个文件夹下的image_2的文件夹仅有400张图像。那么我们现在开始学习代码吧!并且我们这里提一下image_0与image_1是灰度的图像,所以我们这里主要使用image_2和image_3。
在ARFlow的工程中定义了三个dataset的类,分别是raw数据,无监督训练,只验证三个dataset,我们这里因为不使用raw数据,所以只考虑后面两个。
在看代码之前,我们先把标签数据复制一下过来,但是我们训练的时候跳过9~12帧,为了后面使用这些数据进行验证的时候用的,因为测试集并没有公开标注,这也是现在主流的做法。
class KITTIFlowMV(ImgSeqDataset):
"""
这个数据集仅用于无监督训练
"""
def __init__(self, root, n_frames=2,
transform=None, co_transform=None, ap_transform=None, ):
super(KITTIFlowMV, self).__init__(root, n_frames,
input_transform=transform,
co_transform=co_transform,
ap_transform=ap_transform)
def collect_samples(self):
# 设置光流图路径
flow_occ_dir = 'flow_' + 'occ'
assert (self.root / flow_occ_dir).isdir()
# 设置左右图像路径
img_l_dir, img_r_dir = 'image_2', 'image_3'
assert (self.root / img_l_dir).isdir() and (self.root / img_r_dir).isdir()
samples = []
# 遍历光流图目录下所有以 '.png' 结尾的文件
for flow_map in sorted((self.root / flow_occ_dir).glob('*.png')):
flow_map = flow_map.basename()
root_filename = flow_map[:-7]
for img_dir in [img_l_dir, img_r_dir]:
# 获取与光流图文件名相关的图像文件列表
img_list = (self.root / img_dir).files('*{}*.png'.format(root_filename))
img_list.sort()
for st in range(0, len(img_list) - self.n_frames + 1):
# 获取图像序列
seq = img_list[st:st + self.n_frames]
sample = {}
sample['imgs'] = []
for i, file in enumerate(seq):
# 解析文件名中的帧号
frame_id = int(file[-6:-4])
# 如果帧号在 9 和 12 之间(包括 9 和 12),终止循环
if 12 >= frame_id >= 9:
break
sample['imgs'].append(self.root.relpathto(file))
# 如果采样到了足够的图像帧,添加到样本列表中
if len(sample['imgs']) == self.n_frames:
samples.append(sample)
return samples
后来我去看有的论文其做法只是在kitti2015的多视角进行训练,其中把9~12帧去掉,将这些有标注的数据做为验证集,这里我有个疑惑就是,对于2015行了,直接使用多视角训练后进行验证,但是对于2012呢?难道不再使用2012的多视角进行训练,然后在2012的验证集上验证吗?这现在就是我的困惑?
我悟了,感谢灵感,刚才我有个念头,就是让我去看一下好了,结果我对比了一下,发现是一样的,哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈,那就证明了我的猜想,单用kitti2015进行训练,然后再使用这个模型直接对2012和2015的论文进行验证。
ARFlow的作者,你为什么这么贴心呢?这么有爱呢?生怕别人不理解,你就好像一座黑暗中的灯塔照亮抹黑前行的人们,我如果发出了论文,我一定会感谢你的,我也会使用的你的论文进行改进,谢谢你们!
刚才我又重新看了ARFlow的文档,发现作者也是贴心的附上了sintel的raw数据集,我真的是,其实对于无监督学习来说,应该发挥自己的优势,就是可以使用更多的没有标注的数据,这就是优势嘛!
这里最后给验证数据dataset进行注释一下:
class KITTIFlow(ImgSeqDataset):
"""
该数据集仅用于验证,因此关于目标的所有文件都以文件路径形式存储,
并且目标的转换没有进行任何处理。
"""
def __init__(self, root, n_frames=2, transform=None):
super(KITTIFlow, self).__init__(root, n_frames, input_transform=transform)
def __getitem__(self, idx):
# 获取样本信息
s = self.samples[idx]
# 对于2帧,img 1和2用于构成图像序列;对于3帧,img 0、1和2用于构成图像序列。
st = 1 if self.n_frames == 2 else 0
ed = st + self.n_frames
imgs = [s['img{}'.format(i)] for i in range(st, ed)]
# 读取图像数据并获取原始大小
inputs = [imageio.imread(self.root / p).astype(np.float32) for p in imgs]
raw_size = inputs[0].shape[:2]
# 构建数据字典
data = {
'flow_occ': self.root / s['flow_occ'],
'flow_noc': self.root / s['flow_noc'],
}
# 为测试集添加额外信息
data.update({
'im_shape': raw_size,
'img1_path': self.root / s['img1'],
})
# 如果存在输入变换,对输入图像进行处理
if self.input_transform is not None:
inputs = [self.input_transform(i) for i in inputs]
data.update({'img{}'.format(i + 1): inputs[i] for i in range(self.n_frames)})
return data
def collect_samples(self):
'''在训练文件夹中搜索包含 'flow_noc' 或 'flow_occ'
以及 'colored_0'(KITTI 2012)或 'image_2'(KITTI 2015)的文件夹'''
# 设置光流图路径
flow_occ_dir = 'flow_' + 'occ'
flow_noc_dir = 'flow_' + 'noc'
assert (self.root / flow_occ_dir).isdir()
# 设置图像路径
img_dir = 'image_2'
assert (self.root / img_dir).isdir()
samples = []
# 遍历光流图目录下所有以 '.png' 结尾的文件
for flow_map in sorted((self.root / flow_occ_dir).glob('*.png')):
flow_map = flow_map.basename()
root_filename = flow_map[:-7]
# 构建光流图文件路径
flow_occ_map = flow_occ_dir + '/' + flow_map
flow_noc_map = flow_noc_dir + '/' + flow_map
s = {'flow_occ': flow_occ_map, 'flow_noc': flow_noc_map}
# 构建图像文件路径
img1 = img_dir + '/' + root_filename + '_10.png'
img2 = img_dir + '/' + root_filename + '_11.png'
assert (self.root / img1).isfile() and (self.root / img2).isfile()
s.update({'img1': img1, 'img2': img2})
# 对于3帧,添加额外的图像文件路径
if self.n_frames == 3:
img0 = img_dir + '/' + root_filename + '_09.png'
assert (self.root / img0).isfile()
s.update({'img0': img0})
samples.append(s)
return samples
这里面也有个细节,就是在使用验证的时候,只是使用了image2这个视角的图像,并没有使用image3。
sintel_raw
因为我们获取到了这个数据,现在还是弄一下这个代码吧:
class SintelRaw(ImgSeqDataset):
def __init__(self, root, n_frames=2, transform=None, co_transform=None):
super(SintelRaw, self).__init__(root, n_frames, input_transform=transform,
co_transform=co_transform)
def collect_samples(self):
scene_list = self.root.dirs()
samples = []
for scene in scene_list:
img_list = scene.files('*.png')
img_list.sort()
for st in range(0, len(img_list) - self.n_frames + 1):
seq = img_list[st:st + self.n_frames]
sample = {'imgs': [self.root.relpathto(file) for file in seq]}
samples.append(sample)
return samples
接口文件
我们接下来看一下接口的文件,就是将我们前面的代码集成一下,先看下要引入的包
import copy
from torchvision import transforms
from torch.utils.data import ConcatDataset
from transforms.co_transforms import get_co_transforms
from transforms.ar_transforms.ap_transforms import get_ap_transforms
from transforms import sep_transforms
from datasets.flow_datasets import SintelRaw, Sintel
from datasets.flow_datasets import KITTIRawFile, KITTIFlow, KITTIFlowMV
其实当自己低频的时候,平日里学到的绝大多数知识都用不上,唯有让自己的大脑安静下来,才会获得力量,所以静定慧的静真的就是万道之门。
对应的包
上面的代码中torch.utils.data.ConcatDataset 是 PyTorch 中用于将多个数据集合并成一个的工具类。它接受一个包含多个数据集的列表,并将它们串联在一起,以便在训练和验证中更方便地处理多个数据来源。
并且有很多变换的操作,我们先看一下sep_transforms:
import numpy as np
import torch
# from scipy.misc import imresize
from skimage.transform import resize as imresize
class ArrayToTensor(object):
"""将numpy.ndarray(H x W x C)转换为torch.FloatTensor(C x H x W)。"""
def __call__(self, array):
assert (isinstance(array, np.ndarray))
# 将数组的轴顺序从HWC(Height x Width x Channels)转换为CHW(Channels x Height x Width)
array = np.transpose(array, (2, 0, 1))
# 转换为torch张量
tensor = torch.from_numpy(array)
# 将数据类型转换为float
return tensor.float()
class Zoom(object):
def __init__(self, new_h, new_w):
self.new_h = new_h
self.new_w = new_w
def __call__(self, image):
h, w, _ = image.shape
# 如果图像已经是目标尺寸,则直接返回
if h == self.new_h and w == self.new_w:
return image
# 使用imresize函数调整图像大小
image = imresize(image, (self.new_h, self.new_w))
return image
具体来说,对于一个类的实例 obj,如果该类定义了 call 方法,那么可以使用 obj() 的形式调用这个实例,就像调用一个函数一样
get_co_transforms 函数:用于获取数据增强的组合操作,根据输入的数据增强参数创建对应的操作列表。
Compose 类:将多个数据增强操作组合在一起的类,实例可以像函数一样被调用,对输入图像和目标图像进行组合操作。
RandomCrop 类:在随机位置对给定图像进行裁剪,使其具有给定尺寸的区域。
RandomSwap 类:以一定概率随机交换输入图像列表的顺序。
RandomHorizontalFlip 类:以0.5的概率随机水平翻转给定图像。
import numbers
import random
import numpy as np
# from scipy.misc import imresize
from skimage.transform import resize as imresize
import scipy.ndimage as ndimage
def get_co_transforms(aug_args):
"""
获取数据增强的组合操作。
:param aug_args: 包含数据增强参数的对象。
:return: 数据增强的组合操作。
"""
transforms = []
if aug_args.crop:
transforms.append(RandomCrop(aug_args.para_crop))
if aug_args.hflip:
transforms.append(RandomHorizontalFlip())
if aug_args.swap:
transforms.append(RandomSwap())
return Compose(transforms)
class Compose(object):
def __init__(self, co_transforms):
"""
将多个数据增强操作组合在一起的类。
:param co_transforms: 数据增强操作的列表。
"""
self.co_transforms = co_transforms
def __call__(self, input, target):
"""
对输入图像和目标进行组合的调用方法。
:param input: 输入图像。
:param target: 目标图像。
:return: 组合后的输入和目标。
"""
for t in self.co_transforms:
input, target = t(input, target)
return input, target
class RandomCrop(object):
"""在随机位置对给定图像进行裁剪,使其具有给定尺寸的区域。
尺寸可以是一个元组(目标高度,目标宽度)或一个整数,此时目标将是一个正方形(size, size)。
"""
def __init__(self, size):
if isinstance(size, numbers.Number):
self.size = (int(size), int(size))
else:
self.size = size
def __call__(self, inputs, target):
h, w, _ = inputs[0].shape
th, tw = self.size
if w == tw and h == th:
return inputs, target
x1 = random.randint(0, w - tw)
y1 = random.randint(0, h - th)
inputs = [img[y1: y1 + th, x1: x1 + tw] for img in inputs]
if 'mask' in target:
target['mask'] = target['mask'][y1: y1 + th, x1: x1 + tw]
if 'flow' in target:
target['flow'] = target['flow'][y1: y1 + th, x1: x1 + tw]
return inputs, target
class RandomSwap(object):
'''RandomSwap类:以一定概率随机交换输入图像列表的顺序。'''
def __call__(self, inputs, target):
n = len(inputs)
if random.random() < 0.5:
inputs = inputs[::-1]
if 'mask' in target:
target['mask'] = target['mask'][::-1]
if 'flow' in target:
raise NotImplementedError("swap cannot apply to flow")
return inputs, target
class RandomHorizontalFlip(object):
"""以0.5的概率随机水平翻转给定图像。"""
def __call__(self, inputs, target):
if random.random() < 0.5:
inputs = [np.copy(np.fliplr(im)) for im in inputs]
if 'mask' in target:
target['mask'] = [np.copy(np.fliplr(mask)) for mask in target['mask']]
if 'flow' in target:
for i, flo in enumerate(target['flow']):
flo = np.copy(np.fliplr(flo))
flo[:, :, 0] *= -1
target['flow'][i] = flo
return inputs, target
这些代码我们可以直接复用。最后我们看下最后的文件夹ar变换,这个好像得去看下论文呀。我们这里先把数据dataset弄完
import copy
from torchvision import transforms
from torch.utils.data import ConcatDataset
from transforms.co_transforms import get_co_transforms
from transforms.ar_transforms.ap_transforms import get_ap_transforms
from transforms import sep_transforms
from datasets.flow_datasets import SintelRaw, Sintel
from datasets.flow_datasets import KITTIFlow, KITTIFlowMV
from datasets.flow_datasets import Flychairs
def get_dataset(all_cfg):
cfg = all_cfg.data
# 输入数据变换
input_transform = transforms.Compose([
sep_transforms.ArrayToTensor(),
transforms.Normalize(mean=[0, 0, 0], std=[255, 255, 255]),
])
# 获取协同变换操作
co_transform = get_co_transforms(aug_args=all_cfg.data_aug)
# Sintel_Flow训练
if cfg.type == 'Sintel_Flow':
ap_transform = get_ap_transforms(cfg.at_cfg) if cfg.run_at else None
train_set_1 = Sintel(cfg.root_sintel, n_frames=cfg.train_n_frames, type='clean',
split='training', subsplit=cfg.train_subsplit,
with_flow=False,
ap_transform=ap_transform,
transform=input_transform,
co_transform=co_transform
)
train_set_2 = Sintel(cfg.root_sintel, n_frames=cfg.train_n_frames, type='final',
split='training', subsplit=cfg.train_subsplit,
with_flow=False,
ap_transform=ap_transform,
transform=input_transform,
co_transform=co_transform
)
# 组合训练数据dataset
train_set = ConcatDataset([train_set_1, train_set_2])
# 深度拷贝
valid_input_transform = copy.deepcopy(input_transform)
# 将新创建的缩放变换插入到验证数据变换的第一个位置。
# 这是因为在验证阶段,通常希望对输入图像进行一些标准的预处理,例如调整尺寸,以适应模型的输入要求。
valid_input_transform.transforms.insert(0, sep_transforms.Zoom(*cfg.test_shape))
valid_set_1 = Sintel(cfg.root_sintel, n_frames=cfg.val_n_frames, type='clean',
split='training', subsplit=cfg.val_subsplit,
transform=valid_input_transform,
target_transform={'flow': sep_transforms.ArrayToTensor()}
)
valid_set_2 = Sintel(cfg.root_sintel, n_frames=cfg.val_n_frames, type='final',
split='training', subsplit=cfg.val_subsplit,
transform=valid_input_transform,
target_transform={'flow': sep_transforms.ArrayToTensor()}
)
# 组合验证数据dataset
valid_set = ConcatDataset([valid_set_1, valid_set_2])
# Sintel_Raw训练
elif cfg.type == 'Sintel_Raw':
train_set = SintelRaw(cfg.root_sintel_raw, n_frames=cfg.train_n_frames,
transform=input_transform, co_transform=co_transform)
valid_input_transform = copy.deepcopy(input_transform)
valid_input_transform.transforms.insert(0, sep_transforms.Zoom(*cfg.test_shape))
valid_set_1 = Sintel(cfg.root_sintel, n_frames=cfg.val_n_frames, type='clean',
split='training', subsplit=cfg.val_subsplit,
transform=valid_input_transform,
target_transform={'flow': sep_transforms.ArrayToTensor()}
)
valid_set_2 = Sintel(cfg.root_sintel, n_frames=cfg.val_n_frames, type='final',
split='training', subsplit=cfg.val_subsplit,
transform=valid_input_transform,
target_transform={'flow': sep_transforms.ArrayToTensor()}
)
valid_set = ConcatDataset([valid_set_1, valid_set_2])
# Flycharis上进行预训练
elif cfg.type == 'Flycharis':
train_input_transform = copy.deepcopy(input_transform)
train_input_transform.transforms.insert(0, sep_transforms.Zoom(*cfg.train_shape))
ap_transform = get_ap_transforms(cfg.at_cfg) if cfg.run_at else None
train_set = Flychairs(
cfg.root,
cfg.train_n_frames,
split='training',
transform=train_input_transform,
ap_transform=ap_transform,
co_transform=co_transform # no target here
)
valid_input_transform = copy.deepcopy(input_transform)
valid_input_transform.transforms.insert(0, sep_transforms.Zoom(*cfg.test_shape))
valid_set = Flychairs(
cfg.root,
cfg.train_n_frames,
split='validation',
transform=train_input_transform,
ap_transform=ap_transform,
co_transform=co_transform # no target here
)
# 在多视角KITTI上进行训练
elif cfg.type == 'KITTI_MV':
train_input_transform = copy.deepcopy(input_transform)
train_input_transform.transforms.insert(0, sep_transforms.Zoom(*cfg.train_shape))
root_flow = cfg.root_kitti15 if cfg.train_15 else cfg.root_kitti12
ap_transform = get_ap_transforms(cfg.at_cfg) if cfg.run_at else None
train_set = KITTIFlowMV(
root_flow,
cfg.train_n_frames,
transform=train_input_transform,
ap_transform=ap_transform,
co_transform=co_transform # no target here
)
# 深度拷贝
valid_input_transform = copy.deepcopy(input_transform)
# 将新创建的缩放变换插入到验证数据变换的第一个位置。
# 这是因为在验证阶段,通常希望对输入图像进行一些标准的预处理,例如调整尺寸,以适应模型的输入要求。
valid_input_transform.transforms.insert(0, sep_transforms.Zoom(*cfg.test_shape))
# 验证数据dataset构建
valid_set_1 = KITTIFlow(cfg.root_kitti15, n_frames=cfg.val_n_frames,
transform=valid_input_transform,
)
valid_set_2 = KITTIFlow(cfg.root_kitti12, n_frames=cfg.val_n_frames,
transform=valid_input_transform,
)
valid_set = ConcatDataset([valid_set_1, valid_set_2])
else:
raise NotImplementedError(cfg.type)
return train_set, valid_set
浅拷贝创建一个新对象,然后将原始对象中的元素复制到新对象中。但是,如果原始对象包含嵌套对象(例如列表中包含了另一个列表),则浅拷贝只复制了嵌套对象的引用,而不是创建嵌套对象的新实例。
深度拷贝创建一个新对象,并递归地复制原始对象及其所有嵌套对象的副本。这意味着,即使原始对象包含嵌套对象,深度拷贝也会创建嵌套对象的新实例,而不仅仅是复制引用。