Large Mask Inpainting (LaMa) for Image Inpainting

Paper: Resolution-robust Large Mask Inpainting with Fourier Convolutions
GitHub: https://github.com/advimman/lama
cv_fft_inpainting_lama

Contents:

1. FourierUnit
2. SpectralTransform
3. Fast Fourier convolution (FFC)
4. Refinement

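The large mask inpainting (LaMa) model combines three ideas:

  • a new network architecture built on fast Fourier convolutions (FFC), which gives it a wider receptive field;
  • a perceptual loss with a high receptive field;
  • large training masks, which unlock the potential of the first two.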
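
To make the "wider receptive field" point concrete, here is a minimal sketch of the idea behind a Fourier unit: transform the features with a real 2D FFT, apply a 1x1 convolution in the frequency domain, and transform back. The class name `TinyFourierUnit` is hypothetical, and the sketch leaves out details of the repository's implementation such as normalization, so treat it as an illustration rather than the actual module.

```python
import torch
import torch.nn as nn


class TinyFourierUnit(nn.Module):
    """Simplified Fourier unit (hypothetical name, not the repository class)."""

    def __init__(self, channels):
        super().__init__()
        # Real and imaginary parts are stacked on the channel axis, hence 2 * channels.
        self.conv = nn.Conv2d(2 * channels, 2 * channels, kernel_size=1, bias=False)
        self.act = nn.ReLU()

    def forward(self, x):
        b, c, h, w = x.shape
        # Real 2D FFT over the spatial axes: complex tensor of shape (b, c, h, w // 2 + 1).
        freq = torch.fft.rfft2(x, norm="ortho")
        # Stack real and imaginary parts as channels: (b, 2c, h, w // 2 + 1).
        freq = torch.cat([freq.real, freq.imag], dim=1)
        freq = self.act(self.conv(freq))
        real, imag = torch.chunk(freq, 2, dim=1)
        spectrum = torch.complex(real.contiguous(), imag.contiguous())
        # Inverse FFT back to the spatial domain at the original size.
        return torch.fft.irfft2(spectrum, s=(h, w), norm="ortho")


torch.manual_seed(0)
unit = TinyFourierUnit(channels=4).eval()
x = torch.randn(1, 4, 32, 32)
x_changed = x.clone()
x_changed[0, 0, 0, 0] += 1.0  # perturb a single input pixel

with torch.no_grad():
    y = unit(x)
    diff = (unit(x_changed) - y).abs()

# Fraction of output positions affected by the single-pixel change.
affected = (diff > 1e-6).float().mean().item()
print(y.shape, affected)
```

Because every frequency coefficient depends on every pixel, a pointwise operation in the frequency domain touches the whole image at once. That is why perturbing one pixel changes most of the output here, whereas a single 3x3 convolution would only change a 3x3 neighbourhood.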

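LaMa model structure

1. The input has shape (b,4,h,w): the masked RGB image (3 channels) concatenated with the mask (1 channel).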

masked_img = img * (1 - mask)
masked_img = torch.cat([masked_img, mask], dim=1)

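2. The input is padded by 3 pixels on each side (reflection padding) and passed through a convolution with kernel_size=7, so h and w stay unchanged.
3. Three downsampling steps follow; the final output is x_l:(b,128,h/8,w/8), x_g:(b,384,h/8,w/8).
4. The features pass through 6 to 18 residual blocks.
5. Three upsampling steps produce the output image.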
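
The shapes in steps 1 to 3 can be checked with a small sketch. It uses plain `nn.Conv2d` layers as stand-ins for the FFC layers, and it assumes a global ratio of 0.75, inferred from the 128/384 split above. It only illustrates tensor shapes, not how the real model computes its local and global branches.

```python
import torch
import torch.nn as nn

b, h, w = 2, 64, 64
img = torch.rand(b, 3, h, w)
mask = (torch.rand(b, 1, h, w) > 0.5).float()  # 1 marks pixels to inpaint

# Step 1: hide masked pixels and append the mask as a fourth channel.
masked_img = img * (1 - mask)
x = torch.cat([masked_img, mask], dim=1)

ngf = 64
# Step 2: 3-pixel reflection padding followed by a 7x7 convolution keeps h and w.
stem = nn.Sequential(nn.ReflectionPad2d(3), nn.Conv2d(4, ngf, kernel_size=7, padding=0))
# Step 3: three stride-2 convolutions, doubling the channel count each time.
down = nn.Sequential(*[
    nn.Conv2d(ngf * 2 ** i, ngf * 2 ** (i + 1), kernel_size=3, stride=2, padding=1)
    for i in range(3)
])

with torch.no_grad():
    feat = stem(x)
    bottleneck = down(feat)

# Assumed global ratio of 0.75: 512 channels viewed as 128 local + 384 global.
ratio_g = 0.75
n_global = int(bottleneck.shape[1] * ratio_g)
x_l, x_g = bottleneck[:, :-n_global], bottleneck[:, -n_global:]

print(x.shape, feat.shape, x_l.shape, x_g.shape)
```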

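import torch.nn as nn

# FFC_BN_ACT, FFCResnetBlock, ConcatTupleLayer, LearnableSpatialTransformWrapper
# and get_activation are defined in the LaMa repository linked above.
class FFCResNetGenerator(nn.Module):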
    def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3, n_blocks=9, norm_layer=nn.BatchNorm2d,
                 padding_type='reflect', activation_layer=nn.ReLU,
                 up_norm_layer=nn.BatchNorm2d, up_activation=nn.ReLU(True),
                 init_conv_kwargs={}, downsample_conv_kwargs={}, resnet_conv_kwargs={},
                 spatial_transform_layers=None, spatial_transform_kwargs={},
                 add_out_act=True, max_features=1024, out_ffc=False, out_ffc_kwargs={}):
        assert (n_blocks >= 0)
        super().__init__()

        model = [nn.ReflectionPad2d(3),
                 FFC_BN_ACT(input_nc, ngf, kernel_size=7, padding=0, norm_layer=norm_layer,
                            activation_layer=activation_layer, **init_conv_kwargs)]
        # x_l:(b,64,h,w), x_g:0
        ### downsample
        for i in range(n_downsampling):
            mult = 2 ** i
            if i == n_downsampling - 1:
                cur_conv_kwargs = dict(downsample_conv_kwargs)
                cur_conv_kwargs['ratio_gout'] = resnet_conv_kwargs.get('ratio_gin', 0)
            else:
                cur_conv_kwargs = downsample_conv_kwargs
            model += [FFC_BN_ACT(min(max_features, ngf * mult),
                                 min(max_features, ngf * mult * 2),
                                 kernel_size=3, stride=2, padding=1,
                                 norm_layer=norm_layer,
                                 activation_layer=activation_layer,
                                 **cur_conv_kwargs)]
        # x_l:(b,128,h/8,w/8),x_g:(b,384,h/8,w/8) 
        mult = 2 ** n_downsampling
        feats_num_bottleneck = min(max_features, ngf * mult)

        ### resnet blocks
        for i in range(n_blocks):
            cur_resblock = FFCResnetBlock(feats_num_bottleneck, padding_type=padding_type, activation_layer=activation_layer,
                                          norm_layer=norm_layer, **resnet_conv_kwargs)
            if spatial_transform_layers is not None and i in spatial_transform_layers:
                cur_resblock = LearnableSpatialTransformWrapper(cur_resblock, **spatial_transform_kwargs)
            model += [cur_resblock]
        # x_l:(b,128,h/8,w/8),x_g:(b,384,h/8,w/8) 
        model += [ConcatTupleLayer()]  
        # (b,512,h/8,w/8)

        ### upsample
        for i in range(n_downsampling):
            mult = 2 ** (n_downsampling - i)
            model += [nn.ConvTranspose2d(min(max_features, ngf * mult),
                                         min(max_features, int(ngf * mult / 2)),
                                         kernel_size=3, stride=2, padding=1, output_padding=1),
                      up_norm_layer(min(max_features, int(ngf * mult / 2))),
                      up_activation]
        # (b,64,h,w)
        if out_ffc:
            model += [FFCResnetBlock(ngf, padding_type=padding_type, activation_layer=activation_layer,
                                     norm_layer=norm_layer, inline=True, **out_ffc_kwargs)]

        model += [nn.ReflectionPad2d(3),
                  nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
        # (b,3,h,w)
        if add_out_act:
            model.append(get_activation('tanh' if add_out_act is True else add_out_act))
        self.model = nn.Sequential(*model)

    def forward(self, input):
        return self.model(input)
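
The shape comments in the upsampling loop follow from the output-size formula of `nn.ConvTranspose2d`: with kernel_size=3, stride=2, padding=1 and output_padding=1, each layer exactly doubles h and w. Below is a minimal sketch of that arithmetic in plain Python. The helper names are hypothetical, and h and w are assumed to be divisible by 8.

```python
def conv_transpose_out(size, kernel_size=3, stride=2, padding=1, output_padding=1, dilation=1):
    """Output length of nn.ConvTranspose2d along one spatial axis."""
    return (size - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1


def upsample_trace(h, w, ngf=64, n_downsampling=3, max_features=1024):
    """Trace (channels, h, w) through the upsampling loop, starting at the bottleneck."""
    h, w = h // 2 ** n_downsampling, w // 2 ** n_downsampling
    shapes = [(min(max_features, ngf * 2 ** n_downsampling), h, w)]
    for i in range(n_downsampling):
        mult = 2 ** (n_downsampling - i)
        out_channels = min(max_features, int(ngf * mult / 2))
        h, w = conv_transpose_out(h), conv_transpose_out(w)
        shapes.append((out_channels, h, w))
    return shapes


print(upsample_trace(256, 256))
# [(512, 32, 32), (256, 64, 64), (128, 128, 128), (64, 256, 256)]
```

The first entry is the concatenated bottleneck (128 local + 384 global channels); the last matches the (b,64,h,w) comment before the final 7x7 convolution.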

Mask examples:
1. make_random_irregular_mask
2. make_random_rectangle_mask
3. make_random_superres_mask
4. DumbAreaMaskGenerator
5. OutpaintingMaskGenerator
6. RandomSegmentationMaskGenerator (requires detectron2; omitted here)
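
As a rough illustration of what a rectangle mask generator does, the sketch below draws a few random boxes inside a margin. The function name, parameters and defaults are assumptions made for this example and are not the repository's `make_random_rectangle_mask`.

```python
import numpy as np


def random_rectangle_mask(shape, margin=10, min_size=30, max_size=100, max_times=3, rng=None):
    """Illustrative rectangle mask; names and defaults are assumptions, not the repository API."""
    rng = np.random.default_rng() if rng is None else rng
    height, width = shape
    mask = np.zeros((height, width), dtype=np.float32)
    # Boxes must fit between the margins.
    max_size = min(max_size, height - 2 * margin, width - 2 * margin)
    for _ in range(int(rng.integers(1, max_times + 1))):
        box_h = int(rng.integers(min_size, max_size + 1))
        box_w = int(rng.integers(min_size, max_size + 1))
        top = int(rng.integers(margin, height - margin - box_h + 1))
        left = int(rng.integers(margin, width - margin - box_w + 1))
        mask[top:top + box_h, left:left + box_w] = 1.0  # 1 marks pixels to inpaint
    return mask[None, ...]  # shape (1, h, w), ready to use as a mask channel


mask = random_rectangle_mask((256, 256), rng=np.random.default_rng(0))
print(mask.shape, mask.mean())
```

Converted to a tensor with a batch dimension, such a mask can be used directly in `img * (1 - mask)` when building the 4-channel network input.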
