手撕代码deep image matting(6):dataset(2)

class DIMDataset(Dataset):
    def __init__(self, split):
        self.split = split

        filename = '{}_names.txt'.format(split)
        with open(filename, 'r') as file:
            self.names = file.read().splitlines()#按行读取文件并存储在names变量中

        self.transformer = data_transforms[split]

    def __getitem__(self, i):
        name = self.names[i]
        fcount = int(name.split('.')[0].split('_')[0])
        bcount = int(name.split('.')[0].split('_')[1])
        im_name = fg_files[fcount]
        bg_name = bg_files[bcount]
        img, alpha, fg, bg = process(im_name, bg_name)

        # crop size 320:640:480 = 1:1:1
        different_sizes = [(320, 320), (480, 480), (640, 640)]
        crop_size = random.choice(different_sizes)

        trimap = gen_trimap(alpha)
        x, y = random_choice(trimap, crop_size)
        img = safe_crop(img, x, y, crop_size)
        alpha = safe_crop(alpha, x, y, crop_size)

        trimap = gen_trimap(alpha)

        # Flip array left to right randomly (prob=1:1)
        if np.random.random_sample() > 0.5:
            img = np.fliplr(img)
            trimap = np.fliplr(trimap)
            alpha = np.fliplr(alpha)

        x = torch.zeros((4, im_size, im_size), dtype=torch.float)
        img = img[..., ::-1]  # RGB
        img = transforms.ToPILImage()(img)#将数据转化成PIL Image类型
        img = self.transformer(img)
        x[0:3, :, :] = img
        x[3, :, :] = torch.from_numpy(trimap.copy() / 255.)

        y = np.empty((2, im_size, im_size), dtype=np.float32)
        y[0, :, :] = alpha / 255.
        mask = np.equal(trimap, 128).astype(np.float32)
        y[1, :, :] = mask

        return x, y

    def __len__(self):
        return len(self.names)

继续拆解dataset的代码,先把完整的放出来占用一下篇幅然后一点一点的往下拆。

上篇文章里把整个process函数的功用说完了,就把产生的返回值在这里再列出来方便后续讲解。

img:合成的前景和背景图

alpha:前景蒙版图

fg:处理过后的前景图片

bg:处理过后的背景图片

接下来从cropsize的注释开始往下搞起。

 # crop size 320:640:480 = 1:1:1
        different_sizes = [(320, 320), (480, 480), (640, 640)]
        crop_size = random.choice(different_sizes)

        trimap = gen_trimap(alpha)
        x, y = random_choice(trimap, crop_size)
        img = safe_crop(img, x, y, crop_size)
        alpha = safe_crop(alpha, x, y, crop_size)

        trimap = gen_trimap(alpha)

关于这里为什么使用这些尺寸,论文里面有提到:

Although our training dataset has 49,300 images, there are only 493 unique objects. To avoid overfitting as well as to leverage the training data more effectively, we use several training strategies. First, we randomly crop 320×320 (image, trimap) pairs centered on pixels in the unknown regions. This increases our sampling space. Second, we also crop training pairs with different sizes (e.g. 480×480, 640×640) and resize them to 320×320. This makes our method more robust to scales and helps the network better learn context and semantics.Third, flipping is performed randomly on each training pair.Fourth, the trimaps are randomly dilated from their ground truth alpha mattes, helping our model to be more robust to the trimap placement. Finally, the training inputs are recreated randomly after each training epoch.

翻译过来就是:为了保证更改有效的训练数据采用了几种训练策略,首先以未知区域的像素为中心随机裁剪320X320的图像,然后裁剪不同大小的训练对(480×480,640×640),并将它们的大小调整为320×320。第三就是进行随机翻转,最后在每个训练的eopch结束后重新创建训练的输入

回到代码里面,原本以为random.choice就是对这三对尺寸进行随机的选择。选择完毕之后

trimap = gen_trimap(alpha),这里trimap保存的是什么数据

按照字面理解来看,trimap这个东西在抠图的发展来说是一个划时代的创造,说白了就是在alpha值的拟合的时候添加一个约束条件:确定是前景的为1确定是背景为0,模棱两可的地方为0.5也就是灰的,抠图的算法就是把0.5的那一部分灰的和整个的边界的拟合。整个图片的文件夹路径已经把前景、背景、前景蒙版值都用到了,可能会有个trimap的位置需要被使用到。(merged是前景和背景合成的,其实这个图片目前来看的作用就是把前景和背景单独拎出来再组合一圈)

 这里单独看看gen_trimap。 

def gen_trimap(alpha):
    k_size = random.choice(range(1, 5))
    iterations = np.random.randint(1, 20)
    kernel = cv.getStructuringElement(cv.MORPH_ELLIPSE, (k_size, k_size))
    dilated = cv.dilate(alpha, kernel, iterations)
    eroded = cv.erode(alpha, kernel, iterations)
    trimap = np.zeros(alpha.shape)
    trimap.fill(128)
    trimap[eroded >= 255] = 255
    trimap[dilated <= 0] = 0
    return trimap

k_size保存从1到5之间随机一个数,iteration保存从1到20之间任意一个整数。(明明是一个意思为什么写成两个样子,秀),在这句kernel = cv.getStructuringElement(cv.MORPH_ELLIPSE, (k_size, k_size)),cv.getStructuringElement返回的是指定形状和尺寸的结构元素,cv.MORPH_ELLIPSE是指椭圆形的形状,尺寸是(k_size, k_size),默认锚点取点位于中心。

 dilated = cv.dilate(alpha, kernel, iterations)和 eroded = cv.erode(alpha, kernel, iterations),这两句就是把alpha这张图片进行了一圈膨胀和腐蚀,kernel是椭圆形的尺寸是(k_size, k_size)的内核,k_size是随机的1到5中间的整数,iteration是膨胀腐蚀的迭代次数,数值从1到20之间随机取整。膨胀的结果赋值给了dilated,腐蚀的结果赋给了eroded,但是alpha本身没变。

  在这几句中,trimap = np.zeros(alpha.shape)
    trimap.fill(128)
    trimap[eroded >= 255] = 255
    trimap[dilated <= 0] = 0

在这里面先初始化了一个全部是128的trimap,这里面使用了np.array的内置索引,由于trimap和eroded以及dilated是同种尺寸的矩阵,所以在eroded大于255的位置上trimap对应位置的数值赋值255,dilated《=0的位置上对应的trimap对应位置为0。这里面有个比较奇怪的地方,在实验的时候正常来说像素在128显示的应该是灰色,但是使用imshow出现的就是纯白,但是像素值还是按照代码的走向来弄的。这是个细节上的问题,但是在源代码源项目上就没有相关的问题出现。

再得到了想要的trimap图像之后,x, y = random_choice(trimap, crop_size) 这里面又一次使用了自定义的函数:random_choice

def random_choice(trimap, crop_size=(320, 320)):
    crop_height, crop_width = crop_size
    y_indices, x_indices = np.where(trimap == unknown_code)
    num_unknowns = len(y_indices)
    x, y = 0, 0
    if num_unknowns > 0:
        ix = np.random.choice(range(num_unknowns))
        center_x = x_indices[ix]
        center_y = y_indices[ix]
        x = max(0, center_x - int(crop_width / 2))
        y = max(0, center_y - int(crop_height / 2))
    return x, y

默认值的crop_size是(320,320),实际传入什么数值就调用相应的尺寸。crop_height和crop_width提取出来crop_size的对应数值作为宽和高。关于这句y_indices, x_indices = np.where(trimap == unknown_code),这里面unknown_code在config.py中设定为默认值128,np.where就是在找trimap元素值等于128的那一部分,返回的y_indices以及x_indices就是对应的y以及x坐标位置。实验出来是这个样子:

 由于像素值128对应的像素位置都是未知前景和背景的,所以num_unknowns保存的是对应像素的数量。为什么保存的是y_indices的数量,应该是x_indices和y_indices的个数是一样的,因为是对应点位的坐标,而且个数也是像素是128对应的点,所以保存哪个都无所谓。再往下看。

if num_unknowns > 0:
        ix = np.random.choice(range(num_unknowns))
        center_x = x_indices[ix]
        center_y = y_indices[ix]
        x = max(0, center_x - int(crop_width / 2))
        y = max(0, center_y - int(crop_height / 2))

当num_unkowns > 0 时进入判断(如果前面运行顺利的话肯定会出现几个128像素的点,应该说是一堆的点,这样的判断就是为了保证不发生意外),ix保存在0到num_unknowns之间的任意一个数,然后将center_x和center_y对应到下标为ix的位置上。 对比center_x 与 crop_width的1/2 的差值,大于0的话就把x赋值为差值,否则x为0 ,y也是如此思路的处理。到最后整个函数返回x和y。那么x和y到底起什么作用就得继续往下看。

 # crop size 320:640:480 = 1:1:1
        different_sizes = [(320, 320), (480, 480), (640, 640)]
        crop_size = random.choice(different_sizes)

        trimap = gen_trimap(alpha)
        x, y = random_choice(trimap, crop_size)
        img = safe_crop(img, x, y, crop_size)
        alpha = safe_crop(alpha, x, y, crop_size)

        trimap = gen_trimap(alpha)

x和y返回了之后,下一步又是一个函数:safe_crop。又要找定义的那个地方抠了。

def safe_crop(mat, x, y, crop_size=(im_size, im_size)):
    crop_height, crop_width = crop_size
    if len(mat.shape) == 2:
        ret = np.zeros((crop_height, crop_width), np.uint8)
    else:
        ret = np.zeros((crop_height, crop_width, 3), np.uint8)
    crop = mat[y:y + crop_height, x:x + crop_width]
    h, w = crop.shape[:2]
    ret[0:h, 0:w] = crop
    if crop_size != (im_size, im_size):
        ret = cv.resize(ret, dsize=(im_size, im_size), interpolation=cv.INTER_NEAREST)
    return ret

  img = safe_crop(img, x, y, crop_size) 先看看传入的参数。

img:合成的前景和背景图 对应函数的参数mat

x: 0或者是center_x 与 crop_width的1/2 的差值

y:0或者是center_y 与 crop_height的1/2 的差值 

crop_size : 随机选定的三个尺寸之一,im_size是默认的设定320

crop_height和crop_weight保存了crop_size的尺寸后,由于我们现在按照传入了合成后的图像img保存到了mat里面来进行对crop_size函数的分析,mat的尺寸就是3维(x,y,3),但是在后面函数的使用还有这一条:alpha = safe_crop(alpha, x, y, crop_size)  此时mat保存的是alpha也就是二值图像,所以根据mat的尺寸初始化ret的时候就需要加个判断:如果img.shape 的长度是2,就按照(crop_height,crop_width)的尺寸来初始化ret矩阵,也就是说此时的图像是二值图像;如果img.shape 的长度是其他值(也就差不多是3,图像的传入很难有别的维度),就按照(crop_height,crop_width,3)的尺寸来初始化ret矩阵。

crop = mat[y:y + crop_height, x:x + crop_width] 将mat在从y到y_crop_height以及x到x+crop_width的尺寸中的像素点赋给crop,然后将crop的宽和高保存给w和h,之后将ret从 [0,0] 到 [ h,w] 的 像素值都赋值为crop。按照正常思路来说,crop_width = crop_height = w = h  ,如果尺寸和(im_size,im_size)也就是(320,320)不相等的情况下就按照imsize的尺寸进行尺寸重塑。最后返回做好的ret。

这么多次折腾出来的ret到底是什么东西?有必要看看。

 这里面就有点奇怪了:之前并没考虑过一个问题,那就是  crop = mat[y:y + crop_height, x:x + crop_width] 句里面,一旦y到y + crop_height这个区域过了界怎么办?那么mat只能取到图片边缘处,就会导致最后的crop尺寸肯定会比crop_height要小。这里面也有可能完整的把图片尺寸取到320,320。从这里就是做到了论文里提到的随机裁剪并重组尺寸的策略,也就是safe_crop函数的功能。

在这之后alpha又经历了一次safe_crop也就是裁剪又重组成(320,320),在这之后又生成了一次trimap。试验一下看看这次的trimap会有什么不一样。由于x和y都是使用的相同的随机值,所以最后生成出来的图片尺寸都是完全相同的。

 

 此时img经过随机裁剪重塑为(320,320)后的效果:

 在下一步将alpha也就是蒙版值进行裁剪重塑后的alpha效果:

 

 最后一步将这么折腾完毕的alpha再生成trimap,就是这个造型:

 往下走,接下来就是这几句代码:

# Flip array left to right randomly (prob=1:1)
        if np.random.random_sample() > 0.5:
            img = np.fliplr(img)
            trimap = np.fliplr(trimap)
            alpha = np.fliplr(alpha)

np.random.random_sample()  取一个0到1之间的随机浮点数(不等于1),如果这个数字大于0.5,就对img、trimap、alpha都进行左右翻转。这就是论文中提到的最后一条策略:随机反转

最后就一口气把整个getitem看完。

        x = torch.zeros((4, im_size, im_size), dtype=torch.float)
        img = img[..., ::-1]  # RGB
        img = transforms.ToPILImage()(img)#将数据转化成PIL Image类型
        img = self.transformer(img)
        x[0:3, :, :] = img
        x[3, :, :] = torch.from_numpy(trimap.copy() / 255.)

        y = np.empty((2, im_size, im_size), dtype=np.float32)
        y[0, :, :] = alpha / 255.
        mask = np.equal(trimap, 128).astype(np.float32)
        y[1, :, :] = mask

        return x, y

开始的时候就出现一个跟之前完全不太一样的初始化: x = torch.zeros((4, im_size, im_size), dtype=torch.float) 之前使用np.zeros的创建思路是(im_size,im_size,3),为什么这里的通道数量放到了前面?这就得往下看了。

在这句 img = img[..., ::-1] 的后面官方加了个注释:RGB 这里面就有个让人脊背一凉的细节没注意到,那就是cv.imread 是按照BGR的格式读入图像,之前为了试验显示用的cv.imshow是将bgr数据重新以RGB的格式显示,这也就说明:必须要把颜色通道翻转回来才能保证RGB图像数据的输入。然后就把img从RGB图片变成了PIL图片,这样在读取的时候能完全按照RGB三通道读取。之后img就要经历一场transformer,也就是对图片的格式改进。

data_transforms = {
    'train': transforms.Compose([
        transforms.ColorJitter(brightness=0.125, contrast=0.125, saturation=0.125),#改变图片亮度
        transforms.ToTensor(),#转换成tensor然后归一化至0-1,只有PILImage图像能够做到归一化到0-1
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),#图片标准化,即先减均值,再除以标准差,前面的是均值后面的是标准差,由于图片是三个通道所以每一栏都有三个
    ]),
    'valid': transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

这里面有一个细节,在把图片变成tensor格式的时候也就是toTensor这个函数打开的时候出现了一段话: 

 也就是说将PIL Image图片格式或者np.array格式变成tensor类型的时候,会改变尺寸格式,会由原来的(高,宽,通道)变成(通道,高,宽)。 那么img经过了变成PIL又经历了totensor归一化之后变得可以说干干净净白白胖胖的一个标准的三通道的tensor。到这里再看

x[0:3, :, :] = img
x[3, :, :] = torch.from_numpy(trimap.copy() / 255.) 

就明白为什么x的初始化要把通道数放在前面设定。这两个语句的目的就是使得x能够保证4通道的输入,前3个通道也就是0,1,2 保存经过transform后的img,最后一个通道保存trimap。由于img在经过transform的时候经过了转换变成了0到1的数据,trimap并没有进行,所以要将trimap的数据复制出来并除以255保证格式的一致性。再看看要返回的y是什么

  y = np.empty((2, im_size, im_size), dtype=np.float32)
  y[0, :, :] = alpha / 255.
  mask = np.equal(trimap, 128).astype(np.float32)
  y[1, :, :] = mask

其实这里面y在初始化的时候我个人认为用np.zeros也差不多,因为都是先进行了一波初始化然后把alpha/255的值赋值给了y的第一通道。第二通道赋值为mask,就是如果trimap对应的像素值等于128的时候,此时mask值为true,反之为false,再经过float32的转换变成0和1的形式。

 到这里总结一下整个的getitem的返回值:

x: 4通道的输入tensor,前三个为img的变换,最后一个通道为trimap

y : 第一通道为蒙版值alpha的归一化,第二通道为标记像素是否为128也就是是否是模棱两可的边界值

那么dataset的具体实现注释如下:

class DIMDataset(Dataset):
    def __init__(self, split):
        #split传入需要对接的数据及类型,然后根据split直接调用对应的文件,正常可以调用的就两种类型:train 和 valid
        self.split = split

        filename = '{}_names.txt'.format(split)
        with open(filename, 'r') as file:
            self.names = file.read().splitlines()#按行读取文件并存储在names变量中

        self.transformer = data_transforms[split]#根据split对应的值直接调用相对应的transform,transform实现两个功能:将对应图像改变亮度,转化成值在0-1之间的tensor,并按照相关数值标准化

    def __getitem__(self, i):
        name = self.names[i]
        fcount = int(name.split('.')[0].split('_')[0])
        bcount = int(name.split('.')[0].split('_')[1])
        im_name = fg_files[fcount]#前景文件名
        bg_name = bg_files[bcount]#背景文件名
        img, alpha, fg, bg = process(im_name, bg_name)#process根据前景文件名和背景文件名返回四个数据:前景和背景合成后的img,蒙版值alpha,前景图像fg和背景图像bg

        # crop size 320:640:480 = 1:1:1
        different_sizes = [(320, 320), (480, 480), (640, 640)]
        crop_size = random.choice(different_sizes)#随机选择三个尺寸的任意值

        trimap = gen_trimap(alpha)#根据输入的alpha 值返回trimap
        x, y = random_choice(trimap, crop_size) #根据随机选择的尺寸返回随机的起始裁剪坐标
        img = safe_crop(img, x, y, crop_size) #根据裁剪坐标以及输入的合成后的img返回随机裁剪过后的img
        alpha = safe_crop(alpha, x, y, crop_size)#返回和img对应的随机裁剪过后的alpha

        trimap = gen_trimap(alpha)#根据裁剪后的alpha返回新的trimap

        # Flip array left to right randomly (prob=1:1)
        if np.random.random_sample() > 0.5: #随机翻转
            img = np.fliplr(img)
            trimap = np.fliplr(trimap)
            alpha = np.fliplr(alpha)

        x = torch.zeros((4, im_size, im_size), dtype=torch.float)
        img = img[..., ::-1]  # RGB 由于cv.imread读取的数值是bgr,所以要将通道翻转才能生成rgb数据
        img = transforms.ToPILImage()(img)#将数据转化成PIL Image类型
        img = self.transformer(img)#将转换成PIL Image类型的img进行transform转换得到新的tensor数值
        x[0:3, :, :] = img #x的前三个通道赋值为img
        x[3, :, :] = torch.from_numpy(trimap.copy() / 255.)#x的第四个通道设为trimap

        y = np.empty((2, im_size, im_size), dtype=np.float32)
        y[0, :, :] = alpha / 255.#y的第一个通道设为alpha蒙版值
        mask = np.equal(trimap, 128).astype(np.float32)
        y[1, :, :] = mask#y的第二个通道设为trimap与128是否相同的结果,相同的点为1不相同为0

        return x, y

    def __len__(self):
        return len(self.names)

  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值