年初的时候的手撕DIM专栏写完之后,就该进行下一阶段:复现DIM,要知道一个模型一个算法如果没有从头复现出来的话,那就还是没学会整个模型架构的精髓。DIM作为matting领域的开山之作,因其模型极其的简单粗暴和效果极其的好使从而在当时甩了曾经的数学算法流派好几条街,对于有钱人有卡人来说,DIM的训练简直爽到死,但是我没有钱也没有卡,为了卡脸都可以不要的程度。那就只能老老实实从头复现,管它是不是能跑呢大不了我自己跑了。
那么既然要复现就有必要在回顾一下DIM训练的逻辑,然后从后往前一点一点的倒腾。按照之前手撕的顺序,整个DIM项目架构以训练文件train.py的角度梳理一下整个流程:
这就如同把大象装进冰箱一样简单粗暴的三步走战略,但是只要是个明白人都知道train_net才是最麻烦的。这里头就得把train_net再拆解一次:
如果我们自己要复现的话,肯定会挑着那些几乎不会调用其他各种乱七八糟的项目架构的那一块来搞定,那么目前来看也只有args初始化能满足这个条件。但是复现的时候还有几个关键的点:
- 如何保证最后复现出来的项目如同原版那般有序
- 如何保证能跑
也就是说:如果我们把整个功能都复现出来全塞到一个文件里面是完全可以的,但是手撕的过程中就发现:这里面的架构极为有序,对于程序阅读和查找来说特别方便,我们自己复现也必须得做到这个点才行,这对自己的后续检查也是极有好处。那么在上面的架构图就得标上对应使用函数的文件位置这样才好找。
现在来看,如果按照论文从头复现的话,先可着模型model进行复现是最简单直接的,但是要搞定模型的运行的话首要还是数据集的处理代码。因此现在的计划如下:
- args复现,这个是最简单的
- 数据集处理
- 模型搭建
- 训练验证功能
args虽然是最容易的,但是这个是在后面进行训练的时候才能进行彻底弄懂设定的思路,所以还是要老老实实先看看数据集是怎么进行处理的。说句实在话,数据集的处理是在手撕里面最费时间的,但是在后续看别人类似的代码就发现:大家全是这套。
class DIMDataset(Dataset):
def __init__(self, split):
self.split = split
filename = '{}_names.txt'.format(split)
with open(filename, 'r') as file:
self.names = file.read().splitlines()#按行读取文件内容
self.transformer = data_transforms[split]
def __getitem__(self, i):
name = self.names[i]
fcount = int(name.split('.')[0].split('_')[0])
bcount = int(name.split('.')[0].split('_')[1])
im_name = fg_files[fcount]
bg_name = bg_files[bcount]
img, alpha, fg, bg = process(im_name, bg_name)
# crop size 320:640:480 = 1:1:1
different_sizes = [(320, 320), (480, 480), (640, 640)]
crop_size = random.choice(different_sizes)
trimap = gen_trimap(alpha)
x, y = random_choice(trimap, crop_size)
img = safe_crop(img, x, y, crop_size)
alpha = safe_crop(alpha, x, y, crop_size)
trimap = gen_trimap(alpha)
# Flip array left to right randomly (prob=1:1)
if np.random.random_sample() > 0.5:
img = np.fliplr(img)
trimap = np.fliplr(trimap)
alpha = np.fliplr(alpha)
x = torch.zeros((4, im_size, im_size), dtype=torch.float)
img = img[..., ::-1] # RGB
img = transforms.ToPILImage()(img)#将数据转化成PIL Image类型
img = self.transformer(img)
x[0:3, :, :] = img
x[3, :, :] = torch.from_numpy(trimap.copy() / 255.)
y = np.empty((2, im_size, im_size), dtype=np.float32)
y[0, :, :] = alpha / 255.
mask = np.equal(trimap, 128).astype(np.float32)
y[1, :, :] = mask
return x, y
def __len__(self):
return len(self.names)
按照之前手撕代码对于数据集的处理的理解,这里就直接把整个流程整理出来方便后续的复现。
这些流程都是在getitem里面进行的,那么下面就要进行伪代码的复现然后一点一点拼接。
class DIM_DATASET(Dataset):
def __init__(self,split):
#根据输入的split获取文件名
self.split = split
file_name = '{}_name.txt'.format(split)
for f in open(file_name):
self.names += f
#初始化transforms
self.transforms = data_transformer(split)
def __getitem__(self, item):
#根据split进行数据初始化 获取前景文件名和背景文件名
#根据文件名进行处理
#随机裁剪尺寸,生成img,alpha,trimap
#图像随即翻转
#img进行transform
#返回值
def __len__(self):
#返回数据集长度
return len(self.names)
只要开写了就一定会有破点子。首先第一条:在写data_transformer这个组件的时候,按照之前手撕代码的思路来说,data_transformer是按照输入的值进行分辨运行哪种类型,那么换成函数进行的话是否也是可行?这里面就实验了两种思路的实现。
def data_transformer(split):
if split == "train":
return transforms.Compose([
transforms.ColorJitter(brightness=0.125, contrast=0.125, saturation=0.125),#改变图片亮度
transforms.ToTensor(),#转换成tensor然后归一化至0-1
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),#图片标准化,即先减均值,再除以标准差,前面的是标准值后面的是标准差
])
else :
return transforms.Compose([
transforms.ToTensor(),#转换成tensor然后归一化至0-1
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),#图片标准化,即先减均值,再除以标准差,前面的是标准值后面的是标准差
])
if __name__ == '__main__':
data_transforms = {
'train': transforms.Compose([
transforms.ColorJitter(brightness=0.125, contrast=0.125, saturation=0.125), # 改变图片亮度
transforms.ToTensor(), # 转换成tensor然后归一化至0-1
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), # 图片标准化,即先减均值,再除以标准差,前面的是标准值后面的是标准差
]),
'valid': transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]),
}
trans = data_transformer("valid")
trans2 = data_transforms["valid"]
print(trans)
print(trans2)
实验结果就是:这两个东西的功能完全一致。但是在接下来打算进行调用的时候转念一想:这东西在其他文件是不是有调用的情况?果不其然。
这些出现的场合大多数就是import语句写了进去,写完了就在对应文件里面进行使用。import导入的时候几乎都是统一格式:
from data_gen import data_transforms, gen_trimap, fg_test_files, bg_test_files
再来看看新思路的data_transformers,只是把原来的字典格式变成了函数进行判定,不过仔细想想的话的确没有字典格式那么便捷,这里尝试下错误输入时的结果。
如果把函数写完整的话就会直接把错误输出的结果写出来并直接甩出来,但是函数也会正常运行下去,如果加一步触发异常的话,就是纯粹的白费事。所以还是老老实实按照原来的写法来吧。
上面的代码还有一处和原来的不太一样,就是读取文件名这行。原版使用的是file.read.splitline(),我再实现的时候想到的是逐行读取然后相加的方式,实验效果如下所示:
之前解析代码的时候对于train_txt这个文件命名方式有一个特别关键的点:前景位置 _ 背景位置.png,也就是说如果要获得前景和背景文件名找到对应的文件的话,就必须要把这里面每一条信息拆开,提取下横杠 _ 符号前面和后面的数字就能得到,并且dataset是根据每次循环对应的index循环编号进行对应寻找,那么就有必要再试试能否正确得出当前循环所对应的合成图的文件名。
结论就是:沿用原版的代码思路。 接下来就该进行合成图函数process()的流程复现。
def process(im_name, bg_name):
im = cv.imread(fg_path + im_name)
a = cv.imread(a_path + im_name, 0)
h, w = im.shape[:2]
bg = cv.imread(bg_path + bg_name)
bh, bw = bg.shape[:2]
wratio = w / bw
hratio = h / bh
ratio = wratio if wratio > hratio else hratio
if ratio > 1:
bg = cv.resize(src=bg, dsize=(math.ceil(bw * ratio), math.ceil(bh * ratio)), interpolation=cv.INTER_CUBIC)
return composite4(im, bg, a, w, h)
在分析的时候这里面发现了一个可以说是源代码在书写的时候比较混乱的地方:命名规则上面im代表合成图,fg代表前景bg代表背景,但是在process()里面im代表前景,bg代表背景。这个情况是否是根据使用场景不同造成的就只能看看这个函数的使用场景是否仅限于数据集生成。
既然如此的话在这里可以尝试一下规范命名的操作尝试复现:
前景fg,背景bg,蒙版值a,合成图im
那么按照上面的流程复现出来的伪代码就是这样子:
def process(fg_name,bg_name):
#根据相对应路径读取前景数据fg和背景数据 bg以及蒙版值a
fg = cv.imread(fg_path+fg_name)
bg = cv.imread(bg_path+bg_name)
a = cv.imread(a_path+fg_name,1)
#获取前景背景的宽和高并对背景进行适应性修改尺寸
h,w = fg.shape[:2]
bh,bw = bg.shape[:2]
wratio = w / bw
hratio = h / bh
ratio = wratio if wratio > hratio else hratio
if ratio > 1:
bg = cv.resize(src=bg, dsize=(math.ceil(bw * ratio), math.ceil(bh * ratio)), interpolation=cv.INTER_CUBIC)
#输入前景背景宽高进行合成
return composite4(fg,bg,a,h,w)
def composite4(fg, bg,a, h, w):
#前景转换成float32的np
fg = np.array(fg,np.float32)
bh,bw = bg.shape[:2]
x = 0
if bw>w:
x = np.random.randint(0,bw-w)
y = 0
if bh>h:
y = np.random.randint(0,bh-h)
bg = np.array(bg[y:y+h,x:x+w],np.float32)
alpha = np.zeros((h,w,1),np.float32)
alpha[:, :, 0] = a / 255.
im = fg*alpha+(1-alpha)*bg
im = im.astype(np.uint8)
return im,a,fg,bg
def __getitem__(self, i):
#根据split进行数据初始化 获取前景文件名和背景文件名
img_name = self.names[i]
fg_num = img_name.split('.')[0].split('_')[0]
bg_num = img_name.split('.')[0].split('_')[1]
fg_name = fg_files[fg_num]
bg_name = bg_files[bg_num]
#根据文件名进行处理
img,alpha,fg,bg = process(fg_name,bg_name)
写到这里就发现几个之前忽略掉的问题:
1.resize函数并不是裁剪,而是将大小缩放
2.x和y的顺序问题,因为后续重建alpha变量的时候x为横向y为纵向,实际上在np.array重组的时候是以几行几列进行的重建,和对应习惯坐标x和y的顺序正好相反
先到这里吧。。。复现真的好累