High-Fidelity Pluralistic Image Completion with Transformers

以往的image inpainting模型通常是基于CNN的网络来进行的(不仅是image inpainting,自从CNN出现以来,几乎所有的图像处理模型都是基于其进行的)。然而作者对于这一直以来的传统提出了一个问题:CNN在进行特征提取的时候,每个像素只与周围的卷积核大小的像素有关,这就使得其无法获得全局结构信息。因此,作者在这篇文章中首次提出使用最近大火的image transformer来解决这个问题。

本文基本思路:CNN可以生成细致的texture信息,然而难以获取全局结构信息;相反,transformer可以很容易的获取全局结构信息,因此本文通过transformer生成缺失部分的结构信息,在此基础上再通过CNN来生成texture信息。

虽然图像可以直接经过resize后输入到transformer进行处理,但是这么做的计算量非常巨大(图像长度的平方),所以在这里作者使用了两个小trick来处理图像:
1、将图像进行下采样以降低分辨率。经过下采样后,图像虽然丢失了大多数texture信息,但是绝大多数的结构信息是被保留下来的。
2、通过将RGB的ImageNet进行像素聚类,获得聚类中心,用最近的聚类中心的序列来表示输入图像的像素(这一步的聚类和查询个人没有看明白)。此外对于缺失部分,作者使用特殊的[mask]进行标记,这种特殊的标记就是后面transformer网络所要学习的目标。

通过前面的操作,作者成功将一张带有缺失的图像转化为了一个表征序列 X = { x 1 , x 2 , ⋅ ⋅ ⋅ , x L } X = \{x_1 ,x_2 ,··· ,x_L\} X={x1,x2,,xL}(之所以叫做表征序列,是因为到现在为止还没有进行特征提取的步骤,前面没有涉及到可学习网络)。之后通过一个映射网络,将前面的L维表征向量转化为d维特征向量,进行进一步降维。之后为了使得向量元素之间不至于丢失空间信息,作者加入了 position embeddings,这个embedding的作用就是把位置信息映射为位置空间的一个点,从而给输入的向量提供了一种相对位置信息。通过这一步,得到一个 E ∈ R L × d E \in R^{L \times d} ERL×d的输入,接下来就是将其输入到transformer网络中进行全连接的计算。

transformer由N个self-attention transformer层组成,每一层的结构是在这里插入图片描述
其中MSA是multi-head self-attention,MLP就是普通的全连接层。其中对于MSA,给定一个E:
在这里插入图片描述
其中 W Q j , W K j , W V j W^j_Q,W^j_K,W^j_V WQj,WKj,WVj是可以学习的映射, W O W_O WO也是一个可以学习的映射,用于融合前面的多头attention信息。这里与传统transformer不一样的地方是,传统的transformer在计算的时候是依照 auto-regressive,也就是依照马尔科夫性进行的,也就是在计算一个[mask]的时候只使用其前面位置的信息,而本文为了使得每个[mask]都能获得全局信息,打破了这种马尔科夫性,通过图示来看:
在这里插入图片描述
前两个是传统的transformer,后两个是本文的transformer。最后,通过将transformer结果映射到前面聚类结果的512像素点的一个概率(512类的分类),从而获得每个位置最大概率的像素结果,由此,整个transformer的优化目标就是:在这里插入图片描述
其中 X − π X_{- \pi} Xπ是所有可观测像素的信息, x π k x_{\pi_k} xπk是第 k k k个缺失位置的像素。这里优化了一个每个缺失与所有观测的最大似然概率。我感觉这个式子形象的解释了这里使用transformer的优势所在:每个像素都可以与全局可观测信息相联系。

获得了每个像素的最大似然,最后就是通过这些最大似然概率来取样。因为我们想让生成的像素之间有着一定的联系,因此为了获得更合理的取样结果,我们不能每个都独立的来取。这里作者引入了一种取样技巧:Gibbs sampling。简单来说,就是后一个取样概率会受到前一个取样结果的影响,用公式表示就是: p ( x π k ∣ X − π , X < π k , θ ) p(x_{\pi_k}|X_{−\pi} ,X < \pi_k ,\theta) p(xπkXπ,X<πk,θ)。我们一次取样,那么之前的[mask]位置就会被依次填充。之后我们将得到的结果通过查询前面的聚类结果,就可以获得一个 X ∈ R L × 3 X \in R^{L\times 3} XRL×3的表征序列,然后通过resize,将其变成一个 L × L × 3 \sqrt{L} \times \sqrt{L} \times 3 L ×L ×3的低分辨率完整图像。
至此,第一步,也是本文的核心步骤结束!
接下来的步骤就很好理解了,通过将得到的低分辨率完整图像双三次插值上采样到原始分辨率,然后和原来的缺失图像concatenation在一起进行编码→解码,最终得到更加清晰和富有texture的完整图像。在这一步,用到的损失是平常的重建损失和对抗损失。
最后需要说明的一点是,即使本文在使用transformer的时候,为了降低计算量,使用了大量的努力,然而最终的计算量依然是不容小视的,作者在FFHQ数据集中使用了8块TeslaV100,在Places2 and ImageNet中更是使用了惊人的32块TeslaV100,。

  • 0
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值