论文阅读
- paper
- github
- 论文阅读笔记
- AOT源码解析1-数据集处理
- AOT源码解析2-encoder+decoder
- AOT源码解析3-模型训练
- AOT源码解析4.1-model主体
- AOT源码解析4.2-model主体
- AOT源码解析4.3-model主体
- AOT源码解析4.4-model主体
- AOT源码解析4.5-model主体
代码分析
1静态数据处理
视频目标分割中的静态数据处理,通常是将单帧静态图像按照旋转、平移、裁剪、缩放、翻转等数据增强操作获得一系列不同的图像,再将这些图像组合在一起当成一段训练视频。除此之外,部分论文还会选择将生成得到的两个图像视频合并重叠在一起,得到一个更为复杂的新训练数据,以此来提升模型的鲁棒性。
1.1引入包
#用于从 __future__ 模块中导入特定的特性,使得这些特性在当前的 Python 版本中可用,即使这些特性在更高版本的 Python 中才被默认引入。
from __future__ import division
import os
#允许查找全文文件
from glob import glob
#用于编码和解码json数据
import json
import random
import cv2
from PIL import Image
import numpy as np
import torch
#Dataset 是一个抽象类,用于定义自定义数据集的接口。通过继承 torch.utils.data.Dataset 并实现特定的方法
from torch.utils.data import Dataset
#包含多种类型的图像变换:缩放、裁剪、归一化等
import torchvision.transforms as TF
import dataloaders.image_transforms as IT
#在 OpenCV 中,某些操作可能会使用多线程来加速处理。cv2.setNumThreads() 函数允许你指定 OpenCV 应该使用多少线程来执行这些操作。
#0表示 OpenCV 应该使用所有可用的硬件线程,它会自动选择最佳的线程数
cv2.setNumThreads(0)
1.2 继承Dataset类
"""
===============================================Dataset用法==============================================================
1、Dataset中常用的特定方法:__init__进行初始化,__len__得到数据数量,__getitem__获取数据。
2、一般流程为:先对dataset进行初始化,得到数据地址、数据名列表等。再使用数据名列表得到数据数量,并将该数量传给len。
3、getitem从len得到索引,索引从0~len-1.然后通过索引获取数据。然后使用Dataloader读取Dataset传出的数据。
------------------------------------------------------------------------------------------------------------------------
"""
1.3 数据初始化
#===========================================将静态图像转换成生成视频,用于预训练===============================================
class StaticTrain(Dataset):
def __init__(self,
root,
output_size,
seq_len=5,
max_obj_n=10,
dynamic_merge=True,
merge_prob=1.0,
aug_type='v1'): #aug_type是两种不同的图像增强方式
self.root = root #数据集根目录,根目录底下为imge、annotation,这两个下面是各个数据集的数据
self.clip_n = seq_len #序列长度
self.output_size = output_size #输出的图像大小
self.max_obj_n = max_obj_n #图像中的最大对象数量,用于生成one-hot编码,为每个对象生成唯一标识符
self.dynamic_merge = dynamic_merge
self.merge_prob = merge_prob #合并概率
self.img_list = list()
self.mask_list = list()
#==================================================获取数据列表===========================================================
#1、获取数据集名称,所有数据都按特定格式保存
#2、获取所有图像数据的名称,并存为list并确保所有img都有对应的mask
dataset_list = list() #用于保存使用到的数据集名称
lines = ['COCO', 'ECSSD', 'MSRA10K', 'PASCAL-S', 'PASCALVOC2012']#可能会用到的数据集,训练自己的数据时,在里面加入自己的数据名
for line in lines:
dataset_name = line.strip()#移除字符串开头和末尾的空白字符
img_dir = os.path.join(root, 'JPEGImages', dataset_name)
mask_dir = os.path.join(root, 'Annotations', dataset_name)
img_list = sorted(glob(os.path.join(img_dir, '*.jpg'))) + \
sorted(glob(os.path.join(img_dir, '*.png')))#搜索img_dir文件夹下的所有.jpg和.png结尾的文件
mask_list = sorted(glob(os.path.join(mask_dir, '*.png')))#同上
if len(img_list) > 0:#确保存在数据
if len(img_list) == len(mask_list):#确保所有img都有对应的mask
dataset_list.append(dataset_name)
self.img_list += img_list
self.mask_list += mask_list
print(f'\t{dataset_name}: {len(img_list)} imgs.')
else:
print(
f'\tPreTrain dataset {dataset_name} has {len(img_list)} imgs and {len(mask_list)} annots. Not match! Skip.'
)
else:
print(
f'\tPreTrain dataset {dataset_name} doesn\'t exist. Skip.')
print(
f'{len(self.img_list)} imgs are used for PreTrain. They are from {dataset_list}.'
)
#====================================================初始化数据增强=============================================================
#1、aug_type定义了增强模式。设为v1时,轻微改变图像的亮度、饱和度等;设为v2时,随机较大的改变图像亮度、饱和度,并随机灰度化图像,随机添加噪声;设为其他值时报错
#2、随机旋转、平移、缩放、剪切、插值、填充颜色
#3、随机裁剪和缩放
#4、将图像转换为张量
#5、将标签转换为one-hot编码
#6、进行归一化
self.aug_type = aug_type
self.pre_random_horizontal_flip = IT.RandomHorizontalFlip(0.5)
self.random_horizontal_flip = IT.RandomHorizontalFlip(0.3)
if self.aug_type == 'v1':
self.color_jitter = TF.ColorJitter(0.1, 0.1, 0.1, 0.03)
elif self.aug_type == 'v2':
self.color_jitter = TF.RandomApply(
[TF.ColorJitter(0.4, 0.4, 0.2, 0.1)], p=0.8)
self.gray_scale = TF.RandomGrayscale(p=0.2)
self.blur = TF.RandomApply([IT.GaussianBlur([.1, 2.])], p=0.3)
else:
assert NotImplementedError
self.random_affine = IT.RandomAffine(degrees=20,
translate=(0.1, 0.1),
scale=(0.9, 1.1),
shear=10,
resample=Image.BICUBIC,
fillcolor=(124, 116, 104))
#scale表示裁剪的区域为80%~100%,ratio: 裁剪区域的宽高比范围,图像缩放时使用的插值方法。Image.BICUBIC 是双三次插值,适用于缩小和放大图像。
base_ratio = float(output_size[1]) / output_size[0]
self.random_resize_crop = IT.RandomResizedCrop(
output_size, (0.8, 1),
ratio=(base_ratio * 3. / 4., base_ratio * 4. / 3.),
interpolation=Image.BICUBIC)
self.to_tensor = TF.ToTensor()
#max_obj_n代表数据集中最大的对象数量,用于确定one-hot编码向量的长度,使每个对象都有唯一编码
self.to_onehot = IT.ToOnehot(max_obj_n, shuffle=True)
self.normalize = TF.Normalize((0.485, 0.456, 0.406),
(0.229, 0.224, 0.225))
1.4 获取数据长度
def __len__(self):
return len(self.img_list)
1.5 获取数据
- 1 总体框架
#========================================获取生成视频,并将图片样本进行合并,生成新的训练数据=====================================
#1、通过seq_num定义生成的视频数据长度
#2、通过旋转、裁剪等数据增强操作生成新视频
#3、若dynamic_merge=TRUE和merge_prob=1,则进行样本合并,生成新的数据
def __getitem__(self, idx):
sample1 = self.sample_sequence(idx) #获取生成的视频
if self.dynamic_merge and (sample1['meta']['obj_num'] == 0
or random.random() < self.merge_prob):
#选取随机索引,并且使得随机到的索引不等于当前索引
rand_idx = np.random.randint(len(self.img_list))
while (rand_idx == idx):
rand_idx = np.random.randint(len(self.img_list))
#获取第二个样本
sample2 = self.sample_sequence(rand_idx)
sample = self.merge_sample(sample1, sample2)
else:
sample = sample1
return sample
def merge_sample(self, sample1, sample2, min_obj_pixels=100):
return _merge_sample(sample1, sample2, min_obj_pixels, self.max_obj_n)
- 2 生成视频
def load_image_in_PIL(self, path, mode='RGB'):
img = Image.open(path)
img.load() # Very important for loading large image,调用 load 方法将图像数据加载到内存中
return img.convert(mode)
def sample_sequence(self, idx):
#读取图像,并读取为PIL格式
img_pil = self.load_image_in_PIL(self.img_list[idx], 'RGB')
mask_pil = self.load_image_in_PIL(self.mask_list[idx], 'P')
frames = []
masks = []
#随机水平翻转
img_pil, mask_pil = self.pre_random_horizontal_flip(img_pil, mask_pil)
# img_pil, mask_pil = self.pre_random_vertical_flip(img_pil, mask_pil)
#===============================================通过循环生成视频===================================================
for i in range(self.clip_n):
img, mask = img_pil, mask_pil
if i > 0:
img, mask = self.random_horizontal_flip(img, mask)
img, mask = self.random_affine(img, mask)
img = self.color_jitter(img)
img, mask = self.random_resize_crop(img, mask)
if self.aug_type == 'v2':
img = self.gray_scale(img)
img = self.blur(img)
#---------------------掩码处理--------------------------
#将mask转为numpy数组
mask = np.array(mask, np.uint8)
if i == 0:
mask, obj_list = self.to_onehot(mask)
obj_num = len(obj_list)
else:
#返回的maskshape为(max_obj_n+1,height,weight),其中0通道为背景掩码,剩下每一个通道代表一个对象的mask
mask, _ = self.to_onehot(mask, obj_list)
mask = torch.argmax(mask, dim=0, keepdim=True)
frames.append(self.normalize(self.to_tensor(img)))#frames存储原始视频帧
masks.append(mask) #masks存储模板值
sample = {
'ref_img': frames[0], #参考图像,即序列中的第一个图像
'prev_img': frames[1],
'curr_img': frames[2:],
'ref_label': masks[0],
'prev_label': masks[1],
'curr_label': masks[2:]
}
sample['meta'] = {
'seq_name': self.img_list[idx],
'frame_num': 1,
'obj_num': obj_num
}
return sample
- 3 样本合并
def _get_images(sample):
return [sample['ref_img'], sample['prev_img']] + sample['curr_img']
def _get_labels(sample):
return [sample['ref_label'], sample['prev_label']] + sample['curr_label']
#===============================================将两个样本融合为一个========================================================
def _merge_sample(sample1, sample2, min_obj_pixels=100, max_obj_n=10):
#提取所有的图像和掩码
sample1_images = _get_images(sample1)
sample2_images = _get_images(sample2)
sample1_labels = _get_labels(sample1)
sample2_labels = _get_labels(sample2)
#obj_idx: 一个用于索引对象的张量,范围从 0 到 max_obj_n * 2。
#selected_idx 和 selected_obj: 用于存储选择的对象索引和对象本身。
obj_idx = torch.arange(0, max_obj_n * 2 + 1).view(max_obj_n * 2 + 1, 1, 1)
selected_idx = None
selected_obj = None
all_img = []
all_mask = []
#========================通过合并不同样本,创建新的训练样本======================
for idx, (s1_img, s2_img, s1_label, s2_label) in enumerate(
zip(sample1_images, sample2_images, sample1_labels,
sample2_labels)):
s2_fg = (s2_label > 0).float()
s2_bg = 1 - s2_fg
merged_img = s1_img * s2_bg + s2_img * s2_fg
merged_mask = s1_label * s2_bg.long() + (
(s2_label + max_obj_n) * s2_fg.long())
merged_mask = (merged_mask == obj_idx).float()
if idx == 0:
after_merge_pixels = merged_mask.sum(dim=(1, 2), keepdim=True)
selected_idx = after_merge_pixels > min_obj_pixels
selected_idx[0] = True
obj_num = selected_idx.sum().int().item() - 1
selected_idx = selected_idx.expand(-1,
s1_label.size()[1],
s1_label.size()[2])
if obj_num > max_obj_n:
selected_obj = list(range(1, obj_num + 1))
random.shuffle(selected_obj)
selected_obj = [0] + selected_obj[:max_obj_n]
merged_mask = merged_mask[selected_idx].view(obj_num + 1,
s1_label.size()[1],
s1_label.size()[2])
if obj_num > max_obj_n:
merged_mask = merged_mask[selected_obj]
merged_mask[0] += 0.1
merged_mask = torch.argmax(merged_mask, dim=0, keepdim=True).long()
all_img.append(merged_img)
all_mask.append(merged_mask)
sample = {
'ref_img': all_img[0],
'prev_img': all_img[1],
'curr_img': all_img[2:],
'ref_label': all_mask[0],
'prev_label': all_mask[1],
'curr_label': all_mask[2:]
}
sample['meta'] = sample1['meta']
sample['meta']['obj_num'] = min(obj_num, max_obj_n)
return sample
2 视频数据处理
这里的做法是在原始视频段上截取一段随机长度的子片段,然后再将该子片段随机分成4段更小的片段,小片段的长度不一。将随机得到的小片段进行数据增强和数据合并(与1 静态数据处理中的一致)后,作为训练数据训练model。ref_frame是参考帧,半监督视频分割网络依据参考帧进行分割;curr_frame是当前帧,用于分割预测;pre_frame是当前帧的前一帧。这里值得注意的是:参考帧要求前景mask不能过小,同时必须包含当前帧中的所有对象。
2.1 数据初始化-父类VOSTrain
#========================================从视频序列中采样序列帧,并进行数据增强和预处理==========================================
class VOSTrain(Dataset):
def __init__(self,
image_root,
label_root,
imglistdic,
transform=None,
rgb=True,
repeat_time=1,
rand_gap=3,
seq_len=5,
rand_reverse=True,
dynamic_merge=True,
enable_prev_frame=False,
merge_prob=0.3,
max_obj_n=10):
self.image_root = image_root #图像根目录
self.label_root = label_root #label根目录
self.rand_gap = rand_gap #随机采样视频间隔
self.seq_len = seq_len #视频段长度
self.rand_reverse = rand_reverse #随机逆序遍历列表,原代码的概率是0.5
self.repeat_time = repeat_time #从视频段中重复采取视频段的次数
self.transform = transform #图像变换
self.dynamic_merge = dynamic_merge #动态合并
self.merge_prob = merge_prob #动态合并概率
self.enable_prev_frame = enable_prev_frame #使用前一帧
self.max_obj_n = max_obj_n #图像的最大对象数量
self.rgb = rgb #rgb图像
self.imglistdic = imglistdic #里面包含各个视频名称,每个视频名称对应的是相应的图像名列表和label名列表
self.seqs = list(self.imglistdic.keys()) #包含的是各个视频名称
print('Video Num: {} X {}'.format(len(self.seqs), self.repeat_time))
2.2 数据初始化-子类DAVIS2017_Train
- DAVIS数据集格式
class DAVIS2017_Train(VOSTrain):
def __init__(self,
split=['train'],
root='./DAVIS',
transform=None,
rgb=True,
repeat_time=1,
full_resolution=True,
year=2017,
rand_gap=3,
seq_len=5,
rand_reverse=True,
dynamic_merge=True,
enable_prev_frame=False,
max_obj_n=10,
merge_prob=0.3):
#选择不同分辨率,这个主要是DAVIS数据集自带的分辨率类别。
#具体DAVIS结构见CSDN图
if full_resolution:
resolution = 'Full-Resolution'
if not os.path.exists(os.path.join(root, 'JPEGImages',
resolution)):
print('No Full-Resolution, use 480p instead.')
resolution = '480p'
else:
resolution = '480p'
image_root = os.path.join(root, 'JPEGImages', resolution)
label_root = os.path.join(root, 'Annotations', resolution)
seq_names = [] #存储的是视频名称
for spt in split:
#这里得到的地址是:root/ImageSets/2017/train.txt
with open(os.path.join(root, 'ImageSets', str(year),
spt + '.txt')) as f:
seqs_tmp = f.readlines()
#创建一个新的列表,其中包含seqs_tmp,列表中每个元素的副本,但已经去除了每个元素两端的空白字符。
seqs_tmp = list(map(lambda elem: elem.strip(), seqs_tmp))
seq_names.extend(seqs_tmp)
imglistdic = {}
for seq_name in seq_names:
#os.path.join(image_root, seq_name)存储的是每个视频文件夹的地址
#images和labels存储的是各个图像名列表
images = list(
np.sort(os.listdir(os.path.join(image_root, seq_name))))
labels = list(
np.sort(os.listdir(os.path.join(label_root, seq_name))))
imglistdic[seq_name] = (images, labels) #存储的是每个视频名称类别对应一个图像路径和一个标签路径
super(DAVIS2017_Train, self).__init__(image_root,
label_root,
imglistdic,
transform,
rgb,
repeat_time,
rand_gap,
seq_len,
rand_reverse,
dynamic_merge,
enable_prev_frame,
merge_prob=merge_prob,
max_obj_n=max_obj_n)
2.3 获得数据长度
def __len__(self):
return int(len(self.seqs) * self.repeat_time)
2.4 获得数据
- 主体框架
def __getitem__(self, idx):
#获得样本
sample1 = self.sample_sequence(idx)
#进行样本合并
if self.dynamic_merge and (sample1['meta']['obj_num'] == 0
or random.random() < self.merge_prob):
rand_idx = np.random.randint(len(self.seqs))
while (rand_idx == (idx % len(self.seqs))):#确保随机生成的索引与当前索引不重复
rand_idx = np.random.randint(len(self.seqs))
sample2 = self.sample_sequence(rand_idx)
sample = self.merge_sample(sample1, sample2)
else:
sample = sample1
return sample
def merge_sample(self, sample1, sample2, min_obj_pixels=100):
return _merge_sample(sample1, sample2, min_obj_pixels, self.max_obj_n)
- 获取样本
def sample_sequence(self, idx):
#通过取余的方法,选取不同的视频类别,并读取数据
idx = idx % len(self.seqs)
seqname = self.seqs[idx]
imagelist, lablist = self.imglistdic[seqname]
frame_num = len(imagelist) #视频中包含的帧数
#随机逆序遍历列表
if self.rand_reverse:
imagelist, lablist = self.reverse_seq(imagelist, lablist)
is_consistent = False #一致性开关,这个用于控制当前帧和前一帧的对象在参考帧中都存在
max_try = 5 #最大循环次数
try_step = 0 #当前循环次数
#只要没有找到一致的帧序列且尝试次数小于最大尝试次数,就继续循环。
#为的是通过随机采样和一致性检查,从视频序列中选取一组有意义的帧,这些帧将被用于后续的视频对象分割任务。
while (is_consistent is False and try_step < max_try):
try_step += 1
# ==============================================随机生成当前间隔===============================================
#在这里seq_len为5,get_curr_gaps会随机生成4(seq_len-1)次当前gap,这些gap是1~4中随机生成。total_gap是所有当前gap的和,相当于最终训练时获得的视频长度
curr_gaps, total_gap = self.get_curr_gaps(self.seq_len - 1)
#==================如果 self.enable_prev_frame 为真,表示允许随机采样前一帧。====================
if self.enable_prev_frame: # prev frame is randomly sampled
# get prev frame
#这里的做法是:视频总长度-训练时获得的视频长度,然后在这个范围内随机选取一个值。这样做的原因就是为了当最终裁取的视频段不会逾越过原视频
prev_index = self.get_prev_index(lablist, total_gap)#根据总间隔获取前一帧的索引。
#获取图像和标签。
prev_image, prev_label = self.get_image_label(
seqname, imagelist, lablist, prev_index)
#提取出唯一的对象标识符
prev_objs = list(np.unique(prev_label))
# get curr frames
#根据前一帧索引和间隔生成当前帧的索引列表。即以前一帧为star,以star+total_gap结尾,截取一段视频段。然后这段视频段被分成seq_len-1个小段
curr_indices = self.get_curr_indices(lablist, prev_index,
curr_gaps)
#遍历 curr_indices,获取当前帧的图像和标签,并收集所有对象
curr_images, curr_labels, curr_objs = [], [], []
for curr_index in curr_indices:
#获取图像和标签
curr_image, curr_label = self.get_image_label(
seqname, imagelist, lablist, curr_index)
c_objs = list(np.unique(curr_label))
curr_images.append(curr_image)
curr_labels.append(curr_label)
curr_objs.extend(c_objs)
#收集前一帧和当前帧的所有对象
objs = list(np.unique(prev_objs + curr_objs))
start_index = prev_index
end_index = max(curr_indices)
# get ref frame
_try_step = 0
ref_index = self.get_ref_index_v2(seqname, lablist)#参考帧随机采样
#如果参考帧索引在前一帧和当前帧索引范围内,重新生成,直到找到一个合适的参考帧。
while (ref_index > start_index and ref_index <= end_index
and _try_step < max_try):
_try_step += 1
ref_index = self.get_ref_index_v2(seqname, lablist)#随机获取参考帧索引,这里要求参考帧的前景mask不能过小
ref_image, ref_label = self.get_image_label(
seqname, imagelist, lablist, ref_index)
ref_objs = list(np.unique(ref_label))
else: # prev frame is next to ref frame,如果 self.enable_prev_frame 为假,表示前一帧是参考帧的下一帧。
# get ref frame,直接使用参考帧索引获取当前帧的索引列表和图像标签。
ref_index = self.get_ref_index_v2(seqname, lablist)
ref_image, ref_label = self.get_image_label(
seqname, imagelist, lablist, ref_index)
ref_objs = list(np.unique(ref_label))
# get curr frames
curr_indices = self.get_curr_indices(lablist, ref_index,
curr_gaps)
curr_images, curr_labels, curr_objs = [], [], []
for curr_index in curr_indices:
curr_image, curr_label = self.get_image_label(
seqname, imagelist, lablist, curr_index)
c_objs = list(np.unique(curr_label))
curr_images.append(curr_image)
curr_labels.append(curr_label)
curr_objs.extend(c_objs)
objs = list(np.unique(curr_objs))
prev_image, prev_label = curr_images[0], curr_labels[0]
curr_images, curr_labels = curr_images[1:], curr_labels[1:]
is_consistent = True#假设帧序列是一致的。
#遍历所有对象,检查它们是否在参考帧中也存在。如果有任何对象在参考帧中不存在,则将 is_consistent 设置为假,并跳出循环。
for obj in objs:
if obj == 0:
continue
if obj not in ref_objs:
is_consistent = False
break
# get meta info
obj_num = list(np.sort(ref_objs))[-1]
sample = {
'ref_img': ref_image,
'prev_img': prev_image,
'curr_img': curr_images,
'ref_label': ref_label,
'prev_label': prev_label,
'curr_label': curr_labels
}
sample['meta'] = {
'seq_name': seqname,
'frame_num': frame_num,
'obj_num': obj_num
}
if self.transform is not None:
sample = self.transform(sample)
return sample
- get_curr_gaps