原文:《Bilateral Grid Learning for Stereo Matching Networks》(CVPR2021)
CVPR 2021 Open Access Repository (thecvf.com)
开源代码:3DCVdeveloper/BGNet: Xuyuhua works (github.com)
此网络是两年前发表的,但在今天的KITTI2015榜单上来看,也属于速度很快、精度尚可的网络。这篇笔记主要记录我学习其源码遇到的一些问题和我自己的理解。
1.BGnet & BGnet+
开源项目中有两个实现:BGnet/BGnet+,其中Plus版本增加了一个视差优化模块,定义在models/submodules2d.py文件中的HourglassRefinement。
类似《Learning for Disparity Estimation through Feature Constancy》中的残差优化操作,但是没有利用特征图或代价体,是直接根据初步视差和右视图重建左视图,将重建误差、原始左视图、初始视差cat到一起,经过一个2D的Hourglass模块得到视差残差,再加上初始视差作为最终输出。即:
#重建左视图
warped_right = disp_warp(right_img, disp)[0] # [B, C, H, W]
#计算重建误差
error = warped_right - left_img # [B, C, H, W]
#链接重建误差和左视图
concat1 = torch.cat((error, left_img), dim=1) # [B, 6, H, W]
conv1 = self.conv1(concat1) # [B, 16, H, W]
conv2 = self.conv2(disp) # [B, 16, H, W]
#链接初始视差
x = torch.cat((conv1, conv2), dim=1) # [B, 32, H, W]
#优化残差模块
residual_disp = Hourglass(x)
#输出=输入+残差
disp_out = F.relu(disp + residual_disp, inplace=True) # [B, 1, H, W]
这个视差优化模块目测尤其提升了视差图边缘质量(右图为PLUS版本)。
2.slicing模块(grid_sample操作)中的切片与插值
这一模块是对使用3D卷积聚合后的cost volume进行某种采样操作,得到最终的代价立方体供后续softmax和视差变换使用。
其中cost volume的构造类似Gwcnet,但只用了“分组相关”一种方式,没有连接concat volume代价。cost volume格式[N,C,D,H,W],其中C为分组数量,可以理解成后续的通道数量。代价聚合使用3D的Hourglass模块,在DHW三个维度上进行3D卷积,输出聚合后代价coeffs,在输出前最后一步将coeffs视差和通道维度互换,调整为了[N,D,C,H,W]。所以slicing模块即对格式为[N,D,C,H,W]的代价立方体进行采样,具体实现如下。
#slice函数
class Slice(SubModule):
def __init__(self):
super(Slice, self).__init__()
def forward(self, bilateral_grid, wg, hg, guidemap):
guidemap = guidemap.permute(0,2,3,1).contiguous() #[B,C,H,W]-> [B,H,W,C]
guidemap_guide = torch.cat([wg, hg, guidemap], dim=3).unsqueeze(1) # Nx1xHxWx3
coeff = F.grid_sample(bilateral_grid, guidemap_guide,align_corners =False)
return coeff.squeeze(2) #[B,1,H,W]
#对coeffs进行slice操作
list_coeffs = torch.split(coeffs,1,dim = 1)
device = list_coeffs[0].get_device()
N, _, H, W = guide.shape
#[H,W]
hg, wg = torch.meshgrid([torch.arange(0, H), torch.arange(0, W)]) # [0,511] HxW
if device >= 0:
hg = hg.to(device)
wg = wg.to(device)
#[B,H,W,1]
hg = hg.float().repeat(N, 1, 1).unsqueeze(3) / (H-1) * 2 - 1 # norm to [-1,1] NxHxWx1
wg = wg.float().repeat(N, 1, 1).unsqueeze(3) / (W-1) * 2 - 1 # norm to [-1,1] NxHxWx1
slice_dict = []
# torch.cuda.synchronize()
# start = time.time()
for i in range(25):
slice_dict.append(self.slice(list_coeffs[i], wg, hg, guide)) #[B,1,H,W]
其中hg\wg是grid_sample函数中输出采样的H和W坐标,C坐标由左视图产生的引导图guide确定。我的理解就是在guide图中相似的像素,就让他们在同一通道C上采样视差代价,并遍历25个视差层级,由于guide图是原始左图的某种特征,就实现了某种引导滤波(保边滤波?)的效果。
我不理解的是为什么要分维度遍历进行grid_sample操作,grid_sample本身是可以处理多通道数据,如果直接写成以下这样(将D维度按通道维度处理),推理是没有问题的,不知道训练会不会有一些问题?
device = coeffs.get_device()
N, _, H, W = guide.shape
#[H,W]
hg, wg = torch.meshgrid([torch.arange(0, H), torch.arange(0, W)]) # [0,511] HxW
if device >= 0:
hg = hg.to(device)
wg = wg.to(device)
#[B,H,W,1]
hg = hg.float().repeat(N, 1, 1).unsqueeze(3) / (H-1) * 2 - 1 # norm to [-1,1] NxHxWx1
wg = wg.float().repeat(N, 1, 1).unsqueeze(3) / (W-1) * 2 - 1 # norm to [-1,1] NxHxWx1
guide = guide.permute(0,2,3,1).contiguous() #[B,C,H,W]-> [B,H,W,C]
guidemap_guide = torch.cat([wg, hg, guide], dim=3).unsqueeze(1) # Nx1xHxWx3
coeff = F.grid_sample(coeffs, guidemap_guide,align_corners =False) # [B,D,1,H,W]
3.slice后softmax前的操作
开源代码对slice_dict块有如下的操作,生成final_cost_volume后才进行softmax计算,看到github的issue里有人问这一块是干什么的:
wa = wa.view(1,-1,1,1)
wb = wb.view(1,-1,1,1)
wa = wa.to(device)
wb = wb.to(device)
wa = wa.float()
wb = wb.float()
slice_dict_a = []
slice_dict_b = []
for i in range(97):
inx_a = i//4
inx_b = inx_a + 1
inx_b = min(inx_b,24)
slice_dict_a.append(slice_dict[inx_a])
slice_dict_b.append(slice_dict[inx_b])
final_cost_volume = wa * torch.cat(slice_dict_a,dim = 1) + wb * torch.cat(slice_dict_b,dim = 1)
这一块是将[N,25,H,W]的cost volume上采样为[N,97,H,W],采样方法就是朴实无华的线性插值,只不过是作者写死了线性插值函数。为什么要上采样呢,因为在组相关构建cost volume时特征图尺寸是原图1/8,在视差为0-24尺度下构建代价立方,那么实际可处理视差就应该是0-192(24*8)。这里先生成了1/2尺寸的视差图,后续再将实际的视差图双线性插值为原尺寸。1/2尺寸视差为0-96共97层,所以就要将cost volume上采样为[N,97,H,W]啦。
我不太理解的是为什么要手写双线性插值,为什么不直接用interpolate函数呢?当然用interpolate函数就要交换D\C维度,然后做三线性插值。是因为三线性插值开销太大?或者是交换Channel维度会在训练时出现问题?反正改成下面这样也是能跑推理没有问题,不知道训练有没有问题。
#[N,D(25),H,W]->[N,D(97),H,W]
final_cost_volume = F.interpolate(
input = coeff.permute(0,2,1,3,4).contiguous(),
size = (97,H,W),
mode = 'trilinear',
align_corners=True).squeeze(1)
4.推理速度
将BGNET+部署到Xavier nx的上,使用TensorRT的fp16精度进行推理,推理一次大约85ms,如果用INT8量化会更快点,但估计这种数值型输出,精度损失也会更大。好奇还有什么加速方法,感觉剪枝什么的也快不了多少了。ZED相机号称是深度学习方法、在Xavier nx上可以跑到100Hz,如果端到端DeepLearning的立体匹配能做到吗,还是传统方法做初步视差图提议,然后深度学习优化?或者针对网络定制化实现算子会有数量级上的提速?
Anything else:pytorch不知道是从哪个版本开始,squeeze函数在转onnx时会加一个分支判断,在转TensorRT时这个判断会因为涉及比较不同的数据类型而报错,所以部署前最好把所有squeeze函数改成view函数或者别的。