torch.nn.functional.grid_sample() 注意点
用法: 主要用于采样,一般是使用bilinear根据grid的坐标采样
F.grid_sample(img, grid, align_corners=True)
img 是采样的空间,grid是生成的网格坐标。
grid通常由torch.meshgrid()生成,且要映射到(-1,1)之间,如:
dx = torch.linspace(-1,1, 9)
dy = torch.linspace(-1, 1,7)
coords = torch.stack(torch.meshgrid(dy, dx), axis=-1) #[dy*dx*2]
输入输出情况:
假如是4D 的input:
img.shape : [B,C,H_in,W_in]
grid.shape: [B,H_out,W_out,2]
out: [B,C,H_out,W_out]
细节:
1.为什么meshgrid生成坐标的时候,stack成coords时需要逆序(第一层是y,第二层是x)?
Ans:采样的时候,在img上取点,坐标是根据grid来的,grid[:,:,0]是W维度的坐标,grid[:,:,1]是H维度的坐标,所以这个地方需要注意,是反过来的
2.grid的形状仅仅影响output的形状,直接决定取点的还是坐标,所以尤其要注意grid坐标叠。