BG-Net双目立体匹配_细节笔记

原文:《Bilateral Grid Learning for Stereo Matching Networks》(CVPR2021)

CVPR 2021 Open Access Repository (thecvf.com)

开源代码:3DCVdeveloper/BGNet: Xuyuhua works (github.com)​​​​​​

此网络是两年前发表的,但在今天的KITTI2015榜单上来看,也属于速度很快、精度尚可的网络。这篇笔记主要记录我学习其源码遇到的一些问题和我自己的理解。

1.BGnet & BGnet+

开源项目中有两个实现:BGnet/BGnet+,其中Plus版本增加了一个视差优化模块,定义在models/submodules2d.py文件中的HourglassRefinement。

类似《Learning for Disparity Estimation through Feature Constancy》中的残差优化操作,但是没有利用特征图或代价体,是直接根据初步视差和右视图重建左视图,将重建误差、原始左视图、初始视差cat到一起,经过一个2D的Hourglass模块得到视差残差,再加上初始视差作为最终输出。即:

#重建左视图
warped_right = disp_warp(right_img, disp)[0]  # [B, C, H, W]
#计算重建误差
error = warped_right - left_img  # [B, C, H, W]
#链接重建误差和左视图
concat1 = torch.cat((error, left_img), dim=1)  # [B, 6, H, W]
conv1 = self.conv1(concat1)  # [B, 16, H, W]
conv2 = self.conv2(disp)  # [B, 16, H, W]
#链接初始视差
x = torch.cat((conv1, conv2), dim=1)  # [B, 32, H, W]
#优化残差模块
residual_disp = Hourglass(x) 
#输出=输入+残差
disp_out = F.relu(disp + residual_disp, inplace=True)  # [B, 1, H, W]

这个视差优化模块目测尤其提升了视差图边缘质量(右图为PLUS版本)。

 

2.slicing模块(grid_sample操作)中的切片与插值

这一模块是对使用3D卷积聚合后的cost volume进行某种采样操作,得到最终的代价立方体供后续softmax和视差变换使用。

其中cost volume的构造类似Gwcnet,但只用了“分组相关”一种方式,没有连接concat  volume代价。cost volume格式[N,C,D,H,W],其中C为分组数量,可以理解成后续的通道数量。代价聚合使用3D的Hourglass模块,在DHW三个维度上进行3D卷积,输出聚合后代价coeffs,在输出前最后一步将coeffs视差和通道维度互换,调整为了[N,D,C,H,W]。所以slicing模块即对格式为[N,D,C,H,W]的代价立方体进行采样,具体实现如下。

#slice函数
class Slice(SubModule):
    def __init__(self):
        super(Slice, self).__init__()
    
    def forward(self, bilateral_grid, wg, hg, guidemap): 
        guidemap = guidemap.permute(0,2,3,1).contiguous() #[B,C,H,W]-> [B,H,W,C]
        guidemap_guide = torch.cat([wg, hg, guidemap], dim=3).unsqueeze(1) # Nx1xHxWx3        
        coeff = F.grid_sample(bilateral_grid, guidemap_guide,align_corners =False)
        return coeff.squeeze(2) #[B,1,H,W]
        #对coeffs进行slice操作
        list_coeffs = torch.split(coeffs,1,dim = 1)
        device = list_coeffs[0].get_device()

        N, _, H, W = guide.shape
        #[H,W]
        hg, wg = torch.meshgrid([torch.arange(0, H), torch.arange(0, W)]) # [0,511] HxW
        if device >= 0:
            hg = hg.to(device)
            wg = wg.to(device)
        #[B,H,W,1]
        hg = hg.float().repeat(N, 1, 1).unsqueeze(3) / (H-1) * 2 - 1 # norm to [-1,1] NxHxWx1
        wg = wg.float().repeat(N, 1, 1).unsqueeze(3) / (W-1) * 2 - 1 # norm to [-1,1] NxHxWx1
        slice_dict = []
        # torch.cuda.synchronize()
        # start = time.time()
        for i in range(25):
            slice_dict.append(self.slice(list_coeffs[i], wg, hg, guide)) #[B,1,H,W]

其中hg\wg是grid_sample函数中输出采样的H和W坐标,C坐标由左视图产生的引导图guide确定。我的理解就是在guide图中相似的像素,就让他们在同一通道C上采样视差代价,并遍历25个视差层级,由于guide图是原始左图的某种特征,就实现了某种引导滤波(保边滤波?)的效果。

我不理解的是为什么要分维度遍历进行grid_sample操作,grid_sample本身是可以处理多通道数据,如果直接写成以下这样(将D维度按通道维度处理),推理是没有问题的,不知道训练会不会有一些问题?

        device = coeffs.get_device()
        N, _, H, W = guide.shape
        #[H,W]
        hg, wg = torch.meshgrid([torch.arange(0, H), torch.arange(0, W)]) # [0,511] HxW
        if device >= 0:
            hg = hg.to(device)
            wg = wg.to(device)
        #[B,H,W,1]
        hg = hg.float().repeat(N, 1, 1).unsqueeze(3) / (H-1) * 2 - 1 # norm to [-1,1] NxHxWx1
        wg = wg.float().repeat(N, 1, 1).unsqueeze(3) / (W-1) * 2 - 1 # norm to [-1,1] NxHxWx1

        guide = guide.permute(0,2,3,1).contiguous() #[B,C,H,W]-> [B,H,W,C]
        guidemap_guide = torch.cat([wg, hg, guide], dim=3).unsqueeze(1) # Nx1xHxWx3   
        coeff = F.grid_sample(coeffs, guidemap_guide,align_corners =False) # [B,D,1,H,W]

3.slice后softmax前的操作

开源代码对slice_dict块有如下的操作,生成final_cost_volume后才进行softmax计算,看到github的issue里有人问这一块是干什么的:

        wa  = wa.view(1,-1,1,1)
        wb  = wb.view(1,-1,1,1)
        wa = wa.to(device)
        wb = wb.to(device)
        wa = wa.float()
        wb = wb.float()

        slice_dict_a = []
        slice_dict_b = []
        for i in range(97):
            inx_a = i//4
            inx_b = inx_a + 1
            inx_b  = min(inx_b,24)
            slice_dict_a.append(slice_dict[inx_a])
            slice_dict_b.append(slice_dict[inx_b])
            
        final_cost_volume = wa * torch.cat(slice_dict_a,dim = 1) + wb * torch.cat(slice_dict_b,dim = 1)

这一块是将[N,25,H,W]的cost volume上采样为[N,97,H,W],采样方法就是朴实无华的线性插值,只不过是作者写死了线性插值函数。为什么要上采样呢,因为在组相关构建cost volume时特征图尺寸是原图1/8,在视差为0-24尺度下构建代价立方,那么实际可处理视差就应该是0-192(24*8)。这里先生成了1/2尺寸的视差图,后续再将实际的视差图双线性插值为原尺寸。1/2尺寸视差为0-96共97层,所以就要将cost volume上采样为[N,97,H,W]啦。

我不太理解的是为什么要手写双线性插值,为什么不直接用interpolate函数呢?当然用interpolate函数就要交换D\C维度,然后做三线性插值。是因为三线性插值开销太大?或者是交换Channel维度会在训练时出现问题?反正改成下面这样也是能跑推理没有问题,不知道训练有没有问题。

        #[N,D(25),H,W]->[N,D(97),H,W]
        final_cost_volume = F.interpolate(
                                        input = coeff.permute(0,2,1,3,4).contiguous(),
                                        size = (97,H,W),
                                        mode = 'trilinear',
                                        align_corners=True).squeeze(1)

4.推理速度

将BGNET+部署到Xavier nx的上,使用TensorRT的fp16精度进行推理,推理一次大约85ms,如果用INT8量化会更快点,但估计这种数值型输出,精度损失也会更大。好奇还有什么加速方法,感觉剪枝什么的也快不了多少了。ZED相机号称是深度学习方法、在Xavier nx上可以跑到100Hz,如果端到端DeepLearning的立体匹配能做到吗,还是传统方法做初步视差图提议,然后深度学习优化?或者针对网络定制化实现算子会有数量级上的提速?

Anything else:pytorch不知道是从哪个版本开始,squeeze函数在转onnx时会加一个分支判断,在转TensorRT时这个判断会因为涉及比较不同的数据类型而报错,所以部署前最好把所有squeeze函数改成view函数或者别的。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值