遇到代码里的连续三行分别为标题所述,人没了
需要结合上一篇一起看:python torch 多个矩阵作为矩阵索引
1.使用到的矩阵说明
(省略号处为1024维)
batch_indices=tensor([[0, 0, 0, ..., 0, 0, 0],
[1, 1, 1, ..., 1, 1, 1],
[2, 2, 2, ..., 2, 2, 2],
[3, 3, 3, ..., 3, 3, 3]], device='cuda:0')
row_indices=tensor([[ 0, 1, 2, ..., 1021, 1022, 1023],
[ 0, 1, 2, ..., 1021, 1022, 1023],
[ 0, 1, 2, ..., 1021, 1022, 1023],
[ 0, 1, 2, ..., 1021, 1022, 1023]], device='cuda:0')
itself_indices=tensor([[ 0, 1, 2, ..., 1021, 1022, 1023],
[ 0, 1, 2, ..., 1021, 1022, 1023],
[ 0, 1, 2, ..., 1021, 1022, 1023],
[ 0, 1, 2, ..., 1021, 1022, 1023]], device='cuda:0')
group_idx=tensor([[[ 0, 1, 2, ..., 1021, 1022, 1023],
[ 0, 1, 2, ..., 1021, 1022, 1023],
[ 0, 1, 2, ..., 1021, 1022, 1023],
...,
[ 0, 1, 2, ..., 1021, 1022, 1023],
[ 0, 1, 2, ..., 1021, 1022, 1023],
[ 0, 1, 2, ..., 1021, 1022, 1023]],
[[ 0, 1, 2, ..., 1021, 1022, 1023],
[ 0, 1, 2, ..., 1021, 1022, 1023],
[ 0, 1, 2, ..., 1021, 1022, 1023],
...,
[ 0, 1, 2, ..., 1021, 1022, 1023],
[ 0, 1, 2, ..., 1021, 1022, 1023],
[ 0, 1, 2, ..., 1021, 1022, 1023]],
[[ 0, 1, 2, ..., 1021, 1022, 1023],
[ 0, 1, 2, ..., 1021, 1022, 1023],
[ 0, 1, 2, ..., 1021, 1022, 1023],
...,
[ 0, 1, 2, ..., 1021, 1022, 1023],
[ 0, 1, 2, ..., 1021, 1022, 1023],
[ 0, 1, 2, ..., 1021, 1022, 1023]],
[[ 0, 1, 2, ..., 1021, 1022, 1023],
[ 0, 1, 2, ..., 1021, 1022, 1023],
[ 0, 1, 2, ..., 1021, 1022, 1023],
...,
[ 0, 1, 2, ..., 1021, 1022, 1023],
[ 0, 1, 2, ..., 1021, 1022, 1023],
[ 0, 1, 2, ..., 1021, 1022, 1023]]], device='cuda:0')
temp1=group_idx[batch_indices, row_indices, itself_indices]
temp1=tensor([[ 0, 1, 2, ..., 1021, 1022, 1023],
[ 0, 1, 2, ..., 1021, 1022, 1023],
[ 0, 1, 2, ..., 1021, 1022, 1023],
[ 0, 1, 2, ..., 1021, 1022, 1023]], device='cuda:0')
2、常数赋值给tensor
group_idx[batch_indices, row_indices, itself_indices] = 1024
group_idx=tensor([[[1024, 1, 2, ..., 1021, 1022, 1023],
[ 0, 1024, 2, ..., 1021, 1022, 1023],
[ 0, 1, 1024, ..., 1021, 1022, 1023],
...,
[ 0, 1, 2, ..., 1024, 1022, 1023],
[ 0, 1, 2, ..., 1021, 1024, 1023],
[ 0, 1, 2, ..., 1021, 1022, 1024]],
[[1024, 1, 2, ..., 1021, 1022, 1023],
[ 0, 1024, 2, ..., 1021, 1022, 1023],
[ 0, 1, 1024, ..., 1021, 1022, 1023],
...,
[ 0, 1, 2, ..., 1024, 1022, 1023],
[ 0, 1, 2, ..., 1021, 1024, 1023],
[ 0, 1, 2, ..., 1021, 1022, 1024]],
[[1024, 1, 2, ..., 1021, 1022, 1023],
[ 0, 1024, 2, ..., 1021, 1022, 1023],
[ 0, 1, 1024, ..., 1021, 1022, 1023],
...,
[ 0, 1, 2, ..., 1024, 1022, 1023],
[ 0, 1, 2, ..., 1021, 1024, 1023],
[ 0, 1, 2, ..., 1021, 1022, 1024]],
[[1024, 1, 2, ..., 1021, 1022, 1023],
[ 0, 1024, 2, ..., 1021, 1022, 1023],
[ 0, 1, 1024, ..., 1021, 1022, 1023],
...,
[ 0, 1, 2, ..., 1024, 1022, 1023],
[ 0, 1, 2, ..., 1021, 1024, 1023],
[ 0, 1, 2, ..., 1021, 1022, 1024]]], device='cuda:0')
3.常数和tensor比较大小
//相同两个矩阵的距离矩阵,对角线元素为相同点间的距离为0,矩阵沿主对角线对称,因为第2个点和第1个点的距离等于第1个点和第2个点的距离,也即sqrdists[0][1][0]=sqrdists[0][0][1].sqrdists[0]代表第一个batch.
//sqrdists:4*1024*1024
sqrdists=tensor([[[ 0.0000e+00, 1.1386e-01, 6.3624e-01, ..., 1.3848e+00,
5.5936e-01, 1.0839e+00],
[ 1.1386e-01, 0.0000e+00, 8.9093e-01, ..., 2.2411e+00,
7.8311e-01, 1.8543e+00],
[ 6.3624e-01, 8.9093e-01, -5.9605e-08, ..., 1.5017e+00,
1.5127e-02, 1.2706e+00],
...,
[ 1.3848e+00, 2.2411e+00, 1.5017e+00, ..., -5.9605e-08,
1.6157e+00, 1.8451e-02],
[ 5.5936e-01, 7.8311e-01, 1.5127e-02, ..., 1.6157e+00,
0.0000e+00, 1.3624e+00],
[ 1.0839e+00, 1.8543e+00, 1.2706e+00, ..., 1.8451e-02,
1.3624e+00, 5.9605e-08]],
[[ 0.0000e+00, 1.5277e+00, 6.1708e-01, ..., 9.5098e-01,
2.1011e-01, 1.2949e-02],
[ 1.5277e+00, 0.0000e+00, 3.4608e+00, ..., 9.0734e-02,
1.2185e+00, 1.2821e+00],
[ 6.1708e-01, 3.4608e+00, 1.1921e-07, ..., 2.7502e+00,
1.5215e+00, 7.9318e-01],
...,
[ 9.5098e-01, 9.0734e-02, 2.7502e+00, ..., 0.0000e+00,
6.4702e-01, 7.4993e-01],
[ 2.1011e-01, 1.2185e+00, 1.5215e+00, ..., 6.4702e-01,
0.0000e+00, 1.4260e-01],
[ 1.2949e-02, 1.2821e+00, 7.9318e-01, ..., 7.4993e-01,
1.4260e-01, 3.7253e-09]],
[[-1.1921e-07, 6.5509e-01, 5.0194e-01, ..., 1.9291e-01,
1.1051e+00, 5.2810e-01],
[ 6.5509e-01, 0.0000e+00, 1.4901e-01, ..., 1.8733e-01,
1.7176e-01, 1.5647e-01],
[ 5.0194e-01, 1.4901e-01, 0.0000e+00, ..., 1.9878e-01,
1.3537e-01, 3.1886e-01],
...,
[ 1.9291e-01, 1.8733e-01, 1.9878e-01, ..., -2.9802e-08,
4.8584e-01, 3.0713e-01],
[ 1.1051e+00, 1.7176e-01, 1.3537e-01, ..., 4.8584e-01,
2.9802e-08, 5.2555e-01],
[ 5.2810e-01, 1.5647e-01, 3.1886e-01, ..., 3.0713e-01,
5.2555e-01, 0.0000e+00]],
[[-1.1921e-07, 1.5104e+00, 3.3076e+00, ..., 2.0143e-01,
1.0395e-01, 2.4303e+00],
[ 1.5104e+00, 0.0000e+00, 8.1933e-01, ..., 7.9454e-01,
1.2310e+00, 7.8964e-01],
[ 3.3076e+00, 8.1933e-01, 0.0000e+00, ..., 1.9282e+00,
2.4735e+00, 4.4241e-01],
...,
[ 2.0143e-01, 7.9454e-01, 1.9282e+00, ..., -5.9605e-08,
9.0249e-02, 1.4813e+00],
[ 1.0395e-01, 1.2310e+00, 2.4735e+00, ..., 9.0249e-02,
0.0000e+00, 1.6600e+00],
[ 2.4303e+00, 7.8964e-01, 4.4241e-01, ..., 1.4813e+00,
1.6600e+00, -5.9605e-08]]], device='cuda:0')
radius=0.3
temp2=sqrdists>radius ** 2
temp2=tensor([[[False, True, True, ..., True, True, True],
[ True, False, True, ..., True, True, True],
[ True, True, False, ..., True, False, True],
...,
[ True, True, True, ..., False, True, False],
[ True, True, False, ..., True, False, True],
[ True, True, True, ..., False, True, False]],
[[False, True, True, ..., True, True, False],
[ True, False, True, ..., True, True, True],
[ True, True, False, ..., True, True, True],
...,
[ True, True, True, ..., False, True, True],
[ True, True, True, ..., True, False, True],
[False, True, True, ..., True, True, False]],
[[False, True, True, ..., True, True, True],
[ True, False, True, ..., True, True, True],
[ True, True, False, ..., True, True, True],
...,
[ True, True, True, ..., False, True, True],
[ True, True, True, ..., True, False, True],
[ True, True, True, ..., True, True, False]],
[[False, True, True, ..., True, True, True],
[ True, False, True, ..., True, True, True],
[ True, True, False, ..., True, True, True],
...,
[ True, True, True, ..., False, True, True],
[ True, True, True, ..., True, False, True],
[ True, True, True, ..., True, True, False]]], device='cuda:0')
4. 常数和tensor比较大小后作为tensor索引
temp3=group_idx[sqrdists > radius ** 2]
temp3=tensor([ 1, 2, 3, ..., 1020, 1021, 1022], device='cuda:0')