手撕代码1:deep image matting (4) DIM模型结构

上一篇遗留下来的最大坑也就是整个论文最值钱的地方:dim的模型结构。先看整体的代码的结构框架

使用了Ctrl+F12生成出来的函数内部结构来看,变量这块实际上是缺失了,由于整体的代码长度太大,先看函数构造在细扣里面细节。init和forward是大部分继承pytorch的nn.Module模块的函数必须要具备的两个函数,init作为初始化函数,forward作为前向传播的函数。这里面有一个细节上的问题,在实际使用的时候并没有出现model.forwad()的代码,原因是直接使用函数模块调用DIMModel(参数)的话就会自动使用forward函数。

在进行彻底解构他之前,先祭出来论文的结构模型。

 看论文里面的训练模型来说,数据进入模型之后会经过一系列足够让它怀疑人生的操作:五层下采样,一次卷积加relu,五层上采样,最后再来一个纯卷积,最后出来的蒙版值和原来图片提供的蒙版值也就是ground truth进行损失函数的比对,反向传播回去修改参数,然后下一波按照上面的流程继续跑。看起来这个结构跟后面很多的大牛文章比好理解太多而且效果显著,那么就直接看模型函数的实现细节。

先看看init()初始化函数以及实际使用时的对比

 实际使用时

 在使用的时候,传入DIMModel()有四个参数:n_classes,in_channels,is_unpooling,pretrain ,这四个参数的话inchannel好理解,传入图片的通道数量,这篇文章的创新点就在于传了四个通道,其中第四个通道就是蒙版值用来进行抠图处理的,这个放到后面详细解读。unpooling使用在反卷积层面上的参数,在进行maxpooling之后反卷积回去的时候记住最大值的点以及它的位置,其余点的像素都变为0。这个在CNN概念之上采样,反卷积,Unpooling概念解释_g11d111的博客-CSDN博客_unpooling有详细讲解。剩下的问题就是:n_classes和pretrain到底是干什么的。往后再看看。

n_classes在整个模型的结构里面就出现了很少的次数,

 再看一眼segnetUp1这个函数,n_classes的目的昭然若揭。

 结合整个模型的架构来看,这个东西就是最后上采样输出的维度,默认值为1,也就是上采样出来的最后结果是1维的数据。这个和论文给出来的结构图也是一致的。

那么下一个大坑中的大坑就是:pretrain是干什么用的。直接点击运行之后,pretrain的默认值设为1,那么在模型的函数有一个判断pretrain的代码:

 这又是啥大坑?点开init_vgg16_params()函数看看究竟

    def init_vgg16_params(self, vgg16):
        blocks = [self.down1, self.down2, self.down3, self.down4, self.down5]

        ranges = [[0, 4], [5, 9], [10, 16], [17, 23], [24, 29]]
        features = list(vgg16.features.children())

        vgg_layers = []
        for _layer in features:
            if isinstance(_layer, nn.Conv2d):
                vgg_layers.append(_layer)

        merged_layers = []
        for idx, conv_block in enumerate(blocks):
            if idx < 2:
                units = [conv_block.conv1.cbr_unit, conv_block.conv2.cbr_unit]
            else:
                units = [
                    conv_block.conv1.cbr_unit,
                    conv_block.conv2.cbr_unit,
                    conv_block.conv3.cbr_unit,
                ]
            for _unit in units:
                for _layer in _unit:
                    if isinstance(_layer, nn.Conv2d):
                        merged_layers.append(_layer)

        assert len(vgg_layers) == len(merged_layers)

        for l1, l2 in zip(vgg_layers, merged_layers):
            if isinstance(l1, nn.Conv2d) and isinstance(l2, nn.Conv2d):
                if l1.weight.size() == l2.weight.size() and l1.bias.size() == l2.bias.size():
                    l2.weight.data = l1.weight.data
                    l2.bias.data = l1.bias.data

我不得不承认一个特别废物的现实,我对预训练模型的修改根本就不懂。。。。这里我大概其稍微琢磨了一下,就是把vgg16的预训练模型下载下来之后再用自己的模型掺进去修改一通完事了,因为我不清楚torchvision.model里面的vgg16模型源码是什么样子,所以这个坑,就只能留到以后来解答了。。。。

那么现在我们就假设他用自己的模型从头训练来一步一步的拆解他的模型设置。其实这里面他把很多的步骤都单独封装成类似于segnetUp这样的函数然后直接在模型就进行调用,本身nn.Module的特性就是必须要写forward函数而且直接调用封装好的模型函数本身就会自动走forward,这些代码的确是写的极其方便,至少比我亲身手写的强多了。

class segnetDown2(nn.Module):
    def __init__(self, in_size, out_size):
        super(segnetDown2, self).__init__()
        self.conv1 = conv2DBatchNormRelu(in_size, out_size, k_size=3, stride=1, padding=1)
        self.conv2 = conv2DBatchNormRelu(out_size, out_size, k_size=3, stride=1, padding=1)
        self.maxpool_with_argmax = nn.MaxPool2d(2, 2, return_indices=True)

    def forward(self, inputs):
        outputs = self.conv1(inputs)
        outputs = self.conv2(outputs)
        unpooled_shape = outputs.size()
        outputs, indices = self.maxpool_with_argmax(outputs)
        return outputs, indices, unpooled_shape


class segnetDown3(nn.Module):
    def __init__(self, in_size, out_size):
        super(segnetDown3, self).__init__()
        self.conv1 = conv2DBatchNormRelu(in_size, out_size, k_size=3, stride=1, padding=1)
        self.conv2 = conv2DBatchNormRelu(out_size, out_size, k_size=3, stride=1, padding=1)
        self.conv3 = conv2DBatchNormRelu(out_size, out_size, k_size=3, stride=1, padding=1)
        self.maxpool_with_argmax = nn.MaxPool2d(2, 2, return_indices=True)

    def forward(self, inputs):
        outputs = self.conv1(inputs)
        outputs = self.conv2(outputs)
        outputs = self.conv3(outputs)
        unpooled_shape = outputs.size()
        outputs, indices = self.maxpool_with_argmax(outputs)
        return outputs, indices, unpooled_shape


class segnetUp1(nn.Module):
    def __init__(self, in_size, out_size):
        super(segnetUp1, self).__init__()
        self.unpool = nn.MaxUnpool2d(2, 2)
        self.conv = conv2DBatchNormRelu(in_size, out_size, k_size=5, stride=1, padding=2, with_relu=False)

    def forward(self, inputs, indices, output_shape):
        outputs = self.unpool(input=inputs, indices=indices, output_size=output_shape)
        outputs = self.conv(outputs)
        return outputs

其实这里面有一个细节,在原版的论文模型架构,每一次下采样都是经历了三层的卷积和一个最大池化,在这里面最开始的segnetDown2走了就两层。好听点说节省资源,实际上大家懂的都懂。

  • 2
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值