PyTorch Interpolate align corners解析
PyTorch的interpolate可以用来upsample和downsample tensor。
torch.nn.functional.interpolate(input, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False)
这篇文章主要讨论align_corner
为True
和False
时interpolate是如何做上下采样的。
align_corner = True
解释一下上图。
蓝色的圆点表示的是原本的2D像素的中心点(并没有把所有像素点都表示出来,可以脑补一下,一共应该是16个点,4x4的tensor)。
绿色的圆点是下采样到2x2大小的tensor后的像素中心点,一共四个。
黄色的圆点是上采样到8x8大小的tensor后的像素中心点,一共64个,没有全部标出来。
虚线方框表示的是一个像素的大小。比如下采样时,像素点变少了,所以像素方块变大了(蓝变绿)。上采样,像素点变多,像素方块变小(蓝变黄)。
align_corner=True
的意义
- corner像素:
上采样和下采样后,corner像素的值不变。
无论上采样还是下采样,output的corner像素和input的corner像素的中心点是对齐的。
四个角的像素点其实是蓝绿黄三个像素点重合在一起了。 - 其他像素:
corner像素值是不变的所以很容易知道,不需要计算。其他像素用插值计算(根据选择的mode)。
比如上采样到8x8tensor后,我们知道第一行有8个点(以第一行1D tensor为例),两端点的值已知,中间六个黄色点的位置可以算出来。假设端点位置是0, 1。那么中间分成了7份,所以位置分别是1/7, 2/7…
中间两个蓝色点的位置也可以推算出来。中间分成了3份,所以中间两蓝点位置为1/3和2/3
根据插值可以用蓝色点的值算出黄色点的值是什么。(reminder:蓝色点是input像素中心点,黄色点是上采样output像素的中心点)。
align_corner = False
align_corner=False
的意义
不align corner的情况下,corner input像素的边界和corner output的像素的边界对齐(align corner时是corner像素中心点对齐)。
可以看到上图黄色像素以及绿色像素(output)的边界和蓝色像素(input)的边界是对齐的。
output像素的值,可以根据像素中心点的位置插值得出。
对于超出像素点的像素,需要用边界值padding,原话是:uses edge value padding for out-of-boundary values
代码
x = torch.arange(1, 1 * 1 * 4 * 4 + 1).view(1, 1, 4, 4).float()
print(x)
print_line()
print("upsample, align_corner = True")
output