disparity estimation cost volume
学习depth/disparity estimation过程中cost volume的概念很重要但也困扰了我很久,暂时用一个例子记录一下:
假设有一对图片,已经通过feature extraction提取出了对应的特征,分别表示为refimg_fea 和targetimg_fea (为了能够简单清晰的看出cost volume的构建过程,假设特征的大小(batch_size, channels, width, height)为:[1,3,6,6]),disparity最大为4。cost_volume的构建方法来自于PSMNet。
refimg_fea = tensor = torch.randint(low=0, high=10, size=(1, 3, 6, 6))
targetimg_fea = tensor = torch.randint(low=0, high=10, size=(1, 3, 6, 6))
print('refimg_fea',refimg_fea)
print('targetimg_fea',targetimg_fea)
# print('refimg_fea.size()',refimg_fea.size())
print('refimg_fea.size()[0]',refimg_fea.size()[0]) #batch_size
print('refimg_fea.size()[1]',refimg_fea.size()[1]) #channels
print('refimg_fea.size()[2]',refimg_fea.size()[2]) #width
print('refimg_fea.size()[3]',refimg_fea.size()[3]) #height
maxdisp = 4
#matching
cost = Variable(torch.FloatTensor(refimg_fea.size()[0], refimg_fea.size()[1]*2, maxdisp, refimg_fea.size()[2], refimg_fea.size()[3]).zero_()).cuda()
for i in range(maxdisp):
if i > 0 :
cost[:, :refimg_fea.size()[1], i, :,i:] = refimg_fea[:,:,:,i:]
cost[:, refimg_fea.size()[1]:, i, :,i:] = targetimg_fea[:,:,:,:-i]
else:
cost[:, :refimg_fea.size()[1], i, :,:] = refimg_fea
cost[:, refimg_fea.size()[1]:, i, :,:] = targetimg_fea
cost = cost.contiguous().cpu()
print('cost',cost)
print('cost.size',cost.size())
cost tensor([[[[[4., 3., 6., 4., 9., 2.],
[5., 4., 6., 1., 2., 1.],
[9., 8., 1., 7., 6., 7.],
[1., 9., 7., 8., 8., 6.],
[5., 7., 1., 7., 8., 5.],
[7., 9., 2., 7., 5., 3.]],
[[0., 3., 6., 4., 9., 2.],
[0., 4., 6., 1., 2., 1.],
[0., 8., 1., 7., 6., 7.],
[0., 9., 7., 8., 8., 6.],
[0., 7., 1., 7., 8., 5.],
[0., 9., 2., 7., 5., 3.]],
[[0., 0., 6., 4., 9., 2.],
[0., 0., 6., 1., 2., 1.],
[0., 0., 1., 7., 6., 7.],
[0., 0., 7., 8., 8., 6.],
[0., 0., 1., 7., 8., 5.],
[0., 0., 2., 7., 5., 3.]],
[[0., 0., 0., 4., 9., 2.],
[0., 0., 0., 1., 2., 1.],
[0., 0., 0., 7., 6., 7.],
[0., 0., 0., 8., 8., 6.],
[0., 0., 0., 7., 8., 5.],
[0., 0., 0., 7., 5., 3.]]],
[[[2., 3., 5., 0., 8., 5.],
[5., 9., 7., 6., 8., 9.],
[0., 7., 1., 9., 7., 9.],
[0., 7., 5., 2., 9., 5.],
[3., 6., 0., 5., 0., 3.],
[0., 5., 7., 2., 3., 6.]],
[[0., 3., 5., 0., 8., 5.],
[0., 9., 7., 6., 8., 9.],
[0., 7., 1., 9., 7., 9.],
[0., 7., 5., 2., 9., 5.],
[0., 6., 0., 5., 0., 3.],
[0., 5., 7., 2., 3., 6.]],
[[0., 0., 5., 0., 8., 5.],
[0., 0., 7., 6., 8., 9.],
[0., 0., 1., 9., 7., 9.],
[0., 0., 5., 2., 9., 5.],
[0., 0., 0., 5., 0., 3.],
[0., 0., 7., 2., 3., 6.]],
[[0., 0., 0., 0., 8., 5.],
[0., 0., 0., 6., 8., 9.],
[0., 0., 0., 9., 7., 9.],
[0., 0., 0., 2., 9., 5.],
[0., 0., 0., 5., 0., 3.],
[0., 0., 0., 2., 3., 6.]]],
[[[5., 8., 9., 2., 3., 6.],
[6., 6., 6., 6., 1., 8.],
[5., 6., 8., 4., 5., 9.],
[5., 0., 0., 3., 9., 1.],
[3., 4., 1., 0., 7., 1.],
[5., 8., 1., 6., 8., 1.]],
[[0., 8., 9., 2., 3., 6.],
[0., 6., 6., 6., 1., 8.],
[0., 6., 8., 4., 5., 9.],
[0., 0., 0., 3., 9., 1.],
[0., 4., 1., 0., 7., 1.],
[0., 8., 1., 6., 8., 1.]],
[[0., 0., 9., 2., 3., 6.],
[0., 0., 6., 6., 1., 8.],
[0., 0., 8., 4., 5., 9.],
[0., 0., 0., 3., 9., 1.],
[0., 0., 1., 0., 7., 1.],
[0., 0., 1., 6., 8., 1.]],
[[0., 0., 0., 2., 3., 6.],
[0., 0., 0., 6., 1., 8.],
[0., 0., 0., 4., 5., 9.],
[0., 0., 0., 3., 9., 1.],
[0., 0., 0., 0., 7., 1.],
[0., 0., 0., 6., 8., 1.]]],
[[[0., 2., 2., 1., 3., 9.],
[0., 1., 2., 5., 6., 1.],
[9., 0., 1., 2., 5., 4.],
[5., 5., 0., 1., 8., 1.],
[6., 0., 4., 7., 9., 4.],
[7., 2., 8., 2., 4., 9.]],
[[0., 0., 2., 2., 1., 3.],
[0., 0., 1., 2., 5., 6.],
[0., 9., 0., 1., 2., 5.],
[0., 5., 5., 0., 1., 8.],
[0., 6., 0., 4., 7., 9.],
[0., 7., 2., 8., 2., 4.]],
[[0., 0., 0., 2., 2., 1.],
[0., 0., 0., 1., 2., 5.],
[0., 0., 9., 0., 1., 2.],
[0., 0., 5., 5., 0., 1.],
[0., 0., 6., 0., 4., 7.],
[0., 0., 7., 2., 8., 2.]],
[[0., 0., 0., 0., 2., 2.],
[0., 0., 0., 0., 1., 2.],
[0., 0., 0., 9., 0., 1.],
[0., 0., 0., 5., 5., 0.],
[0., 0., 0., 6., 0., 4.],
[0., 0., 0., 7., 2., 8.]]],
[[[3., 2., 1., 8., 0., 6.],
[9., 0., 4., 0., 0., 3.],
[7., 1., 9., 9., 6., 5.],
[8., 9., 8., 8., 8., 0.],
[3., 0., 1., 6., 8., 4.],
[5., 1., 5., 0., 8., 8.]],
[[0., 3., 2., 1., 8., 0.],
[0., 9., 0., 4., 0., 0.],
[0., 7., 1., 9., 9., 6.],
[0., 8., 9., 8., 8., 8.],
[0., 3., 0., 1., 6., 8.],
[0., 5., 1., 5., 0., 8.]],
[[0., 0., 3., 2., 1., 8.],
[0., 0., 9., 0., 4., 0.],
[0., 0., 7., 1., 9., 9.],
[0., 0., 8., 9., 8., 8.],
[0., 0., 3., 0., 1., 6.],
[0., 0., 5., 1., 5., 0.]],
[[0., 0., 0., 3., 2., 1.],
[0., 0., 0., 9., 0., 4.],
[0., 0., 0., 7., 1., 9.],
[0., 0., 0., 8., 9., 8.],
[0., 0., 0., 3., 0., 1.],
[0., 0., 0., 5., 1., 5.]]],
[[[6., 6., 1., 9., 9., 2.],
[0., 1., 3., 2., 4., 6.],
[5., 0., 1., 7., 0., 2.],
[5., 6., 7., 1., 9., 0.],
[7., 0., 1., 6., 2., 3.],
[1., 8., 5., 9., 4., 1.]],
[[0., 6., 6., 1., 9., 9.],
[0., 0., 1., 3., 2., 4.],
[0., 5., 0., 1., 7., 0.],
[0., 5., 6., 7., 1., 9.],
[0., 7., 0., 1., 6., 2.],
[0., 1., 8., 5., 9., 4.]],
[[0., 0., 6., 6., 1., 9.],
[0., 0., 0., 1., 3., 2.],
[0., 0., 5., 0., 1., 7.],
[0., 0., 5., 6., 7., 1.],
[0., 0., 7., 0., 1., 6.],
[0., 0., 1., 8., 5., 9.]],
[[0., 0., 0., 6., 6., 1.],
[0., 0., 0., 0., 1., 3.],
[0., 0., 0., 5., 0., 1.],
[0., 0., 0., 5., 6., 7.],
[0., 0., 0., 7., 0., 1.],
[0., 0., 0., 1., 8., 5.]]]]])
cost.size torch.Size([1, 6, 4, 6, 6])
从结果可以看出,cost_volume是一个[1,6,4,6,6]的tensor,可以解释为:一个batch_size中有6个(左右2个特征,每个特征通道数为3,摞在一起则为6) [4,6,6]的三维张量,其中,前三个 [4,6,6]的三维张量 对应的是左边图像的特征的三个通道,而后三个 [4,6,6]的三维张量 对应的是右边图像的特征的三个通道。当disparity为0,那么cost volume就直接等于左右两个图像的特征,当disparity为1,那么cost volume就为左边图像第一列到最后一列(从第0列开始数),右边图像的最后一列到倒数第一列(同样从第0列开始数)…一直计算到当disparity为4。