Pytorch Interpolate解析

PyTorch Interpolate align corners解析

PyTorch的interpolate可以用来upsample和downsample tensor。
torch.nn.functional.interpolate(input, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False)

这篇文章主要讨论align_cornerTrueFalse时interpolate是如何做上下采样的。

align_corner = True

请添加图片描述

解释一下上图。

蓝色的圆点表示的是原本的2D像素的中心点(并没有把所有像素点都表示出来,可以脑补一下,一共应该是16个点,4x4的tensor)。
绿色的圆点是下采样到2x2大小的tensor后的像素中心点,一共四个。
黄色的圆点是上采样到8x8大小的tensor后的像素中心点,一共64个,没有全部标出来。
虚线方框表示的是一个像素的大小。比如下采样时,像素点变少了,所以像素方块变大了(蓝变绿)。上采样,像素点变多,像素方块变小(蓝变黄)。

align_corner=True的意义

  1. corner像素:
    上采样和下采样后,corner像素的值不变。
    无论上采样还是下采样,output的corner像素和input的corner像素的中心点是对齐的。
    四个角的像素点其实是蓝绿黄三个像素点重合在一起了。
  2. 其他像素:
    corner像素值是不变的所以很容易知道,不需要计算。其他像素用插值计算(根据选择的mode)。
    比如上采样到8x8tensor后,我们知道第一行有8个点(以第一行1D tensor为例),两端点的值已知,中间六个黄色点的位置可以算出来。假设端点位置是0, 1。那么中间分成了7份,所以位置分别是1/7, 2/7…
    中间两个蓝色点的位置也可以推算出来。中间分成了3份,所以中间两蓝点位置为1/3和2/3
    根据插值可以用蓝色点的值算出黄色点的值是什么。(reminder:蓝色点是input像素中心点,黄色点是上采样output像素的中心点)。

align_corner = False

请添加图片描述

align_corner=False的意义

不align corner的情况下,corner input像素的边界和corner output的像素的边界对齐(align corner时是corner像素中心点对齐)。
可以看到上图黄色像素以及绿色像素(output)的边界和蓝色像素(input)的边界是对齐的。
output像素的值,可以根据像素中心点的位置插值得出。
对于超出像素点的像素,需要用边界值padding,原话是:uses edge value padding for out-of-boundary values

代码

x = torch.arange(1, 1 * 1 * 4 * 4 + 1).view(1, 1, 4, 4).float()
print(x)
print_line()

print("upsample, align_corner = True")
output
  • 2
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值