torch.pairwise_distance(): 计算特征图之间的像素级欧氏距离

torch.pairwise_distance(x1, x2)

这个API可用于计算特征图之间的像素级的距离,输入x1维度为[N,C,H,W],输入x2的维度为[M,C,H,W]。可以通过torch.pairwise_distance(x1, x2)来计算得到像素级距离。

其中要求N==M or N==1 or M==1

这个API我在官方文档没有搜到,而是在通过一篇文章的github源码偶然得知,通过自己的尝试从而总结,如有不全面,还请见谅。

使用示例1

已有模板特征T,其维度为[1,C,H,W],想要计算特征图F(维度为[1, C, H, W])与模板特征之间每个像素点(共HxW个像素)的距离。代码示例如下:

t = torch.randn(1,3,3,3)
f = torch.randn(4,3,3,3)

dist_matrix = torch.pairwise_distance(t, f)
print(dist_matrix.shape)
# torch.Size([4, 3, 3])

使用示例2

已有像素级模板特征T,其维度为[1,C,1,1],想要计算特征图F(维度为[1, C, H, W])中每个像素(共HxW个像素)与模板像素特征的距离。代码示例如下:

	t = torch.randn(1,3,1,1)
	f = torch.randn(4,3,3,3)
	
	dist_matrix = torch.pairwise_distance(t, f)
	print(dist_matrix.shape)
	# torch.Size([4, 3, 3])

还有许多不同的用法,这里不再叙述

正确性检查

因为没有找到对应的官方文档,因此自己写了一些检测程序。代码如下:

程序1

    x = torch.from_numpy(np.array([1,1,1])).float().view(-1,3).unsqueeze(-1).unsqueeze(-1)
    y = torch.from_numpy(np.array([[[3,3,3],[1,1,1]],
                                 [[1,1,1],[1,1,1]]])).float().permute(2,0,1).unsqueeze(0)
    # print(x.shape,'\n',x)
    # print(y.shape,'\n',y)
    dist_matrix = torch.pairwise_distance(x, y)
    print(dist_matrix)

构造x和y,维度上:x为[1,3,1,1],y为[1,3,2,2]。其中y[0,0]与模板像素差距比较大,其它像素位置上与模板像素相同。

输出:

torch.Size([1, 3, 1, 1]) 	# x.shape
 tensor([[[[1.]],	# x

         [[1.]],

         [[1.]]]])
         
torch.Size([1, 3, 2, 2]) 	# y.shape
 tensor([[[[3., 1.],	# y
          [1., 1.]],

         [[3., 1.],
          [1., 1.]],

         [[3., 1.],
          [1., 1.]]]])
          
tensor([[[3.4641e+00, 1.7321e-06],	# dist_matrix
         [1.7321e-06, 1.7321e-06]]])

可以看到除了[0,0]位置上值比较大,其他都接近于0.

程序2

	x = torch.from_numpy(np.array([[1,1,1], [3,3,3]])).float().view(-1,3).unsqueeze(-1).unsqueeze(-1)
    y = torch.from_numpy(np.array([[[3,3,3],[1,1,1]],
                                 [[1,1,1],[1,1,1]]])).float().permute(2,0,1).unsqueeze(0)
    print(x.shape,'\n',x)
    print(y.shape,'\n',y)
    dist_matrix = torch.pairwise_distance(x, y)
    print(dist_matrix)

构造x和y,维度上:x为[2,3,1,1],y为[1,3,2,2]。其中y[0,0]与模板像素特征[0]差距比较大,其它像素位置上与模板像素[0]相同,y[0,0]与模板像素特征[1]相同,其它像素位置上与模板像素[1]差距较大。

torch.Size([2, 3, 1, 1]) 	# x.shape
 tensor([[[[1.]],	# x

         [[1.]],

         [[1.]]],


        [[[3.]],

         [[3.]],

         [[3.]]]])
         
torch.Size([1, 3, 2, 2]) 	# y.shape
 tensor([[[[3., 1.],	# y
          [1., 1.]],

         [[3., 1.],
          [1., 1.]],

         [[3., 1.],
          [1., 1.]]]])
          
tensor([[[3.4641e+00, 1.7321e-06],	# dist_matrix
         [1.7321e-06, 1.7321e-06]],

        [[1.7321e-06, 3.4641e+00],
         [3.4641e+00, 3.4641e+00]]])

可以看到distance_matrix[0]除了[0,0]位置上值比较大,其他都接近于0,而distance_matrix[1]的[0,0]位置上为0。

  • 14
    点赞
  • 31
    收藏
    觉得还不错? 一键收藏
  • 20
    评论
评论 20
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值