论文:《RAFT: Recurrent All-Pairs Field Transforms for Optical Flow》(ECCV 2020)
光流估计模型RAFT是在1/8分辨率的图片上以迭代的方式估计光流,根据论文中消融实验的结果,其上采样方法Convex比双线性上采样有一定的提升,因此作此笔记。
raft的上采样方法其实是一种加权相加的方法,代码中的mask其实就是权重,它是使用GRU模块的输出,再通过两层卷积得到的;另外一个很重要的函数就是F.unfold,这个其实就是使用滑动窗口提取特征然后展平的过程,具体地说比如一个张量为[N,C,H,W],窗口大小设置为K x K,忽略步长等细节,其输出就是[N,C x K x K,H x W],更具体的解释见https://viatorsun.blog.csdn.net/article/details/119940759。
通过合理地设置mask的通道数,reshape后与flow相乘在某个维度上求和,再reshape成所需的大小即可。
def upsample_flow(self, flow, mask):
""" Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """
N, _, H, W = flow.shape
# print('mask:', mask.shape)
# mask: torch.Size([1, 576, 32, 64])
mask = mask.view(N, 1, 9, 8, 8, H, W)
mask = torch.softmax(mask, dim=2)
# print('mask:', mask.shape)
# mask: torch.Size([1, 1, 9, 8, 8, 32, 64])
# print('flow:', flow.shape)
# flow: torch.Size([1, 2, 32, 64])
up_flow = F.unfold(8 * flow, [3, 3], padding=1)
# print('up_flow_0:', up_flow.shape)
# up_flow_0: torch.Size([1, 18, 2048])
up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)
# print('up_flow_1:', up_flow.shape)
# up_flow_1: torch.Size([1, 2, 9, 1, 1, 32, 64])
up_flow = torch.sum(mask * up_flow, dim=2)
# print('up_flow_2:', up_flow.shape)
# up_flow_2: torch.Size([1, 2, 8, 8, 32, 64])
up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
# print('up_flow_3:', up_flow.shape)
# up_flow_3: torch.Size([1, 2, 32, 8, 64, 8])
return up_flow.reshape(N, 2, 8 * H, 8 * W)