Learning Pyramid-Context Encoder Network for High-Quality Image Inpainting

1. Motivation

  • Patch-based methods: they cannot extract high-level image features, so it is hard for them to produce semantically plausible results.
  • Deep-learning-based methods: generating visually plausible results from the image's compact latent features remains challenging.

2. Approach

2.1 Network structure

The generator uses an encoder-decoder structure, where the encoder is pyramid-style. A key component is the Attention Transfer Network (ATN) designed by the authors, which transfers features from the known region to the missing region for better hole filling. To be honest, I couldn't fully understand the ATN from the paper alone and had to read the source code, so I've attached the ATN code below.

The decoder produces images at multiple scales, which are needed for the pyramid L_1 loss (introduced in the loss section below).

import torch
import torch.nn as nn
import torch.nn.functional as F


class AtnConv(nn.Module):
  def __init__(self, input_channels=128, output_channels=64, groups=4, ksize=3, stride=1, rate=2, softmax_scale=10., fuse=True, rates=[1,2,4,8]):
    super(AtnConv, self).__init__()
    self.ksize = ksize
    self.stride = stride
    self.rate = rate 
    self.softmax_scale = softmax_scale
    self.groups = groups
    self.fuse = fuse
    if self.fuse:
      for i in range(groups):
        self.__setattr__('conv{}'.format(str(i).zfill(2)), nn.Sequential(
          nn.Conv2d(input_channels, output_channels//groups, kernel_size=3, dilation=rates[i], padding=rates[i]),
          nn.ReLU(inplace=True))
        )
    
  def forward(self, x1, x2, mask=None):
    """ Attention Transfer Network (ATN) is first proposed in
        Learning Pyramid Context-Encoder Networks for High-Quality Image Inpainting. Yanhong Zeng et al. In CVPR 2019.
      inspired by 
        Generative Image Inpainting with Contextual Attention, Yu et al. In CVPR 2018. 
    Args:
        x1: low-level feature maps with larger resolution.
        x2: high-level feature maps with smaller resolution. 
        mask: Input mask, 1 indicates holes. 
        ksize: Kernel size for contextual attention (set in __init__).
        stride: Stride for extracting patches from x2 (set in __init__).
        rate: Dilation rate for matching (set in __init__).
        softmax_scale: Scale factor for the attention softmax (set in __init__).
    Returns:
        torch.Tensor, reconstructed feature map. 
    """
    # get shapes
    x1s = list(x1.size())
    x2s = list(x2.size())

    # extract patches from low-level feature maps x1 with stride and rate
    kernel = 2*self.rate
    raw_w = extract_patches(x1, kernel=kernel, stride=self.rate*self.stride)
    raw_w = raw_w.contiguous().view(x1s[0], -1, x1s[1], kernel, kernel) # B*HW*C*K*K 
    # split tensors by batch dimension; tuple is returned
    raw_w_groups = torch.split(raw_w, 1, dim=0) 

    # split high-level feature maps x2 for matching 
    f_groups = torch.split(x2, 1, dim=0) 
    # extract patches from x2 as weights of filter
    w = extract_patches(x2, kernel=self.ksize, stride=self.stride)
    w = w.contiguous().view(x2s[0], -1, x2s[1], self.ksize, self.ksize) # B*HW*C*K*K
    w_groups = torch.split(w, 1, dim=0) 

    # process mask
    if mask is not None:
      mask = F.interpolate(mask, size=x2s[2:4], mode='bilinear', align_corners=True)
    else:
      mask = torch.zeros([x2s[0], 1, x2s[2], x2s[3]], device=x2.device)
    # extract patches from masks to mask out hole-patches for matching 
    m = extract_patches(mask, kernel=self.ksize, stride=self.stride)
    m = m.contiguous().view(x2s[0], -1, 1, self.ksize, self.ksize)  # B*HW*1*K*K
    m = m.mean([2,3,4]).unsqueeze(-1).unsqueeze(-1)
    mm = m.eq(0.).float() # (B, HW, 1, 1); 1 for patches fully outside the hole
    mm_groups = torch.split(mm, 1, dim=0)

    y = []
    scale = self.softmax_scale
    padding = 0 if self.ksize==1 else 1
    for xi, wi, raw_wi, mi in zip(f_groups, w_groups, raw_w_groups, mm_groups):
      '''
      O => output channel as a conv filter
      I => input channel as a conv filter
      xi : one batch element of the high-level feature maps x2 (queries for matching);
      wi : one batch element of the patches extracted from x2 (used as conv filters);
      raw_wi : one batch element of the patches extracted from x1 (pasted back during reconstruction);
      '''
      # matching based on cosine-similarity
      wi = wi[0]
      escape_NaN = torch.tensor(1e-4, device=wi.device)
      # normalize 
      wi_normed = wi / torch.max(torch.sqrt((wi*wi).sum([1,2,3],keepdim=True)), escape_NaN)
      yi = F.conv2d(xi, wi_normed, stride=1, padding=padding)
      yi = yi.contiguous().view(1, x2s[2]//self.stride*x2s[3]//self.stride, x2s[2], x2s[3]) 

      # mask out hole patches, then softmax over valid patches to obtain attention scores
      yi = yi * mi 
      yi = F.softmax(yi*scale, dim=1)
      yi = yi * mi
      yi = yi.clamp(min=1e-8)

      # attending: paste the low-level patches back, weighted by the attention scores
      wi_center = raw_wi[0]
      # each output pixel is covered by (kernel/stride)^2 = 4 overlapping patches, hence the /4.
      yi = F.conv_transpose2d(yi, wi_center, stride=self.rate, padding=1) / 4.
      y.append(yi)
    y = torch.cat(y, dim=0)
    y = y.contiguous().view(x1s)
    # adjust after filling 
    if self.fuse:
      tmp = []
      for i in range(self.groups):
        tmp.append(self.__getattr__('conv{}'.format(str(i).zfill(2)))(y))
      y = torch.cat(tmp, dim=1)
    return y


# extract patches
def extract_patches(x, kernel=3, stride=1):
  """Extract sliding patches; returns a (B, H', W', C, kernel, kernel) tensor.
  Note: the fixed zero-padding of 1 is only correct for the kernel/stride
  combinations used above (kernel=3, stride=1 and kernel=4, stride=2)."""
  if kernel != 1:
    x = nn.ZeroPad2d(1)(x)
  x = x.permute(0, 2, 3, 1)
  all_patches = x.unfold(1, kernel, stride).unfold(2, kernel, stride)
  return all_patches
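
As a quick sanity check, here is a minimal, hypothetical usage sketch with random tensors and the default hyper-parameters (my own example, not from the paper or its repository). With rate=2, the low-level map x1 is expected to have twice the spatial resolution of the high-level map x2:

atn = AtnConv(input_channels=128, output_channels=64)
x1 = torch.randn(2, 128, 64, 64)    # low-level features (larger resolution)
x2 = torch.randn(2, 128, 32, 32)    # high-level features (smaller resolution)
mask = torch.zeros(2, 1, 256, 256)  # 1 indicates holes
mask[:, :, 96:160, 96:160] = 1.
y = atn(x1, x2, mask)               # -> (2, 64, 64, 64) after the fuse branch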

2.2 Loss function

Pyramid L1 losses:

    # snippet from the training loop: `feats` holds the decoder's multi-scale outputs
    if feats is not None:
        pyramid_loss = 0 
        for f in feats:
          # compare each scale against the ground truth resized to that scale
          pyramid_loss += self.l1_loss(f, F.interpolate(images, size=f.size()[2:4], mode='bilinear', align_corners=True))
        gen_loss += pyramid_loss * self.config['losses']['pyramid_weight']
        self.add_summary(self.gen_writer, 'loss/pyramid_loss', pyramid_loss.item())

The `feats` in the code are the generated images at different scales (resolutions). The authors want each of these multi-scale outputs to be close to the correspondingly downsampled ground-truth image; in other words, the loss is the sum of the L_1 distances between the generated and real images at each resolution.
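
Written out as a formula (my own notation, matching the code above), with f_l the decoder output at scale l and down_l(x) the ground-truth image x bilinearly resized to that scale:

$$\mathcal{L}_{\text{pyramid}} = \sum_{l} \lVert f_l - \mathrm{down}_l(x) \rVert_1$$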

Adversarial training loss:
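
The network is also trained adversarially against a discriminator. As a minimal sketch, here is a standard non-saturating GAN loss in PyTorch; this is my own illustration, not necessarily the exact GAN variant used in the paper or its repository, and `D`, `real`, and `comp` are hypothetical names for the discriminator, the ground-truth batch, and the inpainted composite:

import torch
import torch.nn as nn

bce = nn.BCEWithLogitsLoss()

def d_loss(D, real, comp):
  # discriminator: push real images toward 1, inpainted composites toward 0
  real_logits = D(real)
  fake_logits = D(comp.detach())  # stop gradients from flowing into the generator
  return bce(real_logits, torch.ones_like(real_logits)) + \
         bce(fake_logits, torch.zeros_like(fake_logits))

def g_loss(D, comp):
  # generator: make the discriminator score the composite as real
  fake_logits = D(comp)
  return bce(fake_logits, torch.ones_like(fake_logits))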

Final loss:
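
Putting the pieces together, the generator is trained with a weighted sum of the pyramid L_1 loss and the adversarial loss. The weights below are hypothetical names in the spirit of the `pyramid_weight` config entry seen above, not the paper's exact notation or values:

$$\mathcal{L} = \lambda_{\text{pyramid}}\,\mathcal{L}_{\text{pyramid}} + \lambda_{\text{adv}}\,\mathcal{L}_{\text{adv}}$$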

3. Discussion

I think the paper has two main contributions: the design of the ATN and the pyramid L_1 loss. To me the more interesting one is the pyramid L_1 loss: by generating images at multiple scales, it imposes a stricter constraint and pushes the model to better results.

4. References

[1] Zeng, Yanhong, et al. "Learning Pyramid-Context Encoder Network for High-Quality Image Inpainting." Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR). 2019.
