python 常数赋值给tensor、常数和tensor比较大小、常数和tensor比较大小后作为tensor索引

遇到代码里的连续三行分别为标题所述,人没了
需要结合上一篇一起看:python torch 多个矩阵作为矩阵索引

1.使用到的矩阵说明

(省略号处为1024维)

batch_indices=tensor([[0, 0, 0,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 1, 1, 1],
        [2, 2, 2,  ..., 2, 2, 2],
        [3, 3, 3,  ..., 3, 3, 3]], device='cuda:0')
row_indices=tensor([[   0,    1,    2,  ..., 1021, 1022, 1023],
        [   0,    1,    2,  ..., 1021, 1022, 1023],
        [   0,    1,    2,  ..., 1021, 1022, 1023],
        [   0,    1,    2,  ..., 1021, 1022, 1023]], device='cuda:0')
itself_indices=tensor([[   0,    1,    2,  ..., 1021, 1022, 1023],
        [   0,    1,    2,  ..., 1021, 1022, 1023],
        [   0,    1,    2,  ..., 1021, 1022, 1023],
        [   0,    1,    2,  ..., 1021, 1022, 1023]], device='cuda:0')
group_idx=tensor([[[   0,    1,    2,  ..., 1021, 1022, 1023],
         [   0,    1,    2,  ..., 1021, 1022, 1023],
         [   0,    1,    2,  ..., 1021, 1022, 1023],
         ...,
         [   0,    1,    2,  ..., 1021, 1022, 1023],
         [   0,    1,    2,  ..., 1021, 1022, 1023],
         [   0,    1,    2,  ..., 1021, 1022, 1023]],

        [[   0,    1,    2,  ..., 1021, 1022, 1023],
         [   0,    1,    2,  ..., 1021, 1022, 1023],
         [   0,    1,    2,  ..., 1021, 1022, 1023],
         ...,
         [   0,    1,    2,  ..., 1021, 1022, 1023],
         [   0,    1,    2,  ..., 1021, 1022, 1023],
         [   0,    1,    2,  ..., 1021, 1022, 1023]],

        [[   0,    1,    2,  ..., 1021, 1022, 1023],
         [   0,    1,    2,  ..., 1021, 1022, 1023],
         [   0,    1,    2,  ..., 1021, 1022, 1023],
         ...,
         [   0,    1,    2,  ..., 1021, 1022, 1023],
         [   0,    1,    2,  ..., 1021, 1022, 1023],
         [   0,    1,    2,  ..., 1021, 1022, 1023]],

        [[   0,    1,    2,  ..., 1021, 1022, 1023],
         [   0,    1,    2,  ..., 1021, 1022, 1023],
         [   0,    1,    2,  ..., 1021, 1022, 1023],
         ...,
         [   0,    1,    2,  ..., 1021, 1022, 1023],
         [   0,    1,    2,  ..., 1021, 1022, 1023],
         [   0,    1,    2,  ..., 1021, 1022, 1023]]], device='cuda:0')
temp1=group_idx[batch_indices, row_indices, itself_indices]

temp1=tensor([[   0,    1,    2,  ..., 1021, 1022, 1023],
        [   0,    1,    2,  ..., 1021, 1022, 1023],
        [   0,    1,    2,  ..., 1021, 1022, 1023],
        [   0,    1,    2,  ..., 1021, 1022, 1023]], device='cuda:0')

2、常数赋值给tensor

group_idx[batch_indices, row_indices, itself_indices] = 1024

group_idx=tensor([[[1024,    1,    2,  ..., 1021, 1022, 1023],
         [   0, 1024,    2,  ..., 1021, 1022, 1023],
         [   0,    1, 1024,  ..., 1021, 1022, 1023],
         ...,
         [   0,    1,    2,  ..., 1024, 1022, 1023],
         [   0,    1,    2,  ..., 1021, 1024, 1023],
         [   0,    1,    2,  ..., 1021, 1022, 1024]],

        [[1024,    1,    2,  ..., 1021, 1022, 1023],
         [   0, 1024,    2,  ..., 1021, 1022, 1023],
         [   0,    1, 1024,  ..., 1021, 1022, 1023],
         ...,
         [   0,    1,    2,  ..., 1024, 1022, 1023],
         [   0,    1,    2,  ..., 1021, 1024, 1023],
         [   0,    1,    2,  ..., 1021, 1022, 1024]],

        [[1024,    1,    2,  ..., 1021, 1022, 1023],
         [   0, 1024,    2,  ..., 1021, 1022, 1023],
         [   0,    1, 1024,  ..., 1021, 1022, 1023],
         ...,
         [   0,    1,    2,  ..., 1024, 1022, 1023],
         [   0,    1,    2,  ..., 1021, 1024, 1023],
         [   0,    1,    2,  ..., 1021, 1022, 1024]],

        [[1024,    1,    2,  ..., 1021, 1022, 1023],
         [   0, 1024,    2,  ..., 1021, 1022, 1023],
         [   0,    1, 1024,  ..., 1021, 1022, 1023],
         ...,
         [   0,    1,    2,  ..., 1024, 1022, 1023],
         [   0,    1,    2,  ..., 1021, 1024, 1023],
         [   0,    1,    2,  ..., 1021, 1022, 1024]]], device='cuda:0')

3.常数和tensor比较大小

//相同两个矩阵的距离矩阵,对角线元素为相同点间的距离为0,矩阵沿主对角线对称,因为第2个点和第1个点的距离等于第1个点和第2个点的距离,也即sqrdists[0][1][0]=sqrdists[0][0][1].sqrdists[0]代表第一个batch.
//sqrdists:4*1024*1024
sqrdists=tensor([[[ 0.0000e+00,  1.1386e-01,  6.3624e-01,  ...,  1.3848e+00,
           5.5936e-01,  1.0839e+00],
         [ 1.1386e-01,  0.0000e+00,  8.9093e-01,  ...,  2.2411e+00,
           7.8311e-01,  1.8543e+00],
         [ 6.3624e-01,  8.9093e-01, -5.9605e-08,  ...,  1.5017e+00,
           1.5127e-02,  1.2706e+00],
         ...,
         [ 1.3848e+00,  2.2411e+00,  1.5017e+00,  ..., -5.9605e-08,
           1.6157e+00,  1.8451e-02],
         [ 5.5936e-01,  7.8311e-01,  1.5127e-02,  ...,  1.6157e+00,
           0.0000e+00,  1.3624e+00],
         [ 1.0839e+00,  1.8543e+00,  1.2706e+00,  ...,  1.8451e-02,
           1.3624e+00,  5.9605e-08]],

        [[ 0.0000e+00,  1.5277e+00,  6.1708e-01,  ...,  9.5098e-01,
           2.1011e-01,  1.2949e-02],
         [ 1.5277e+00,  0.0000e+00,  3.4608e+00,  ...,  9.0734e-02,
           1.2185e+00,  1.2821e+00],
         [ 6.1708e-01,  3.4608e+00,  1.1921e-07,  ...,  2.7502e+00,
           1.5215e+00,  7.9318e-01],
         ...,
         [ 9.5098e-01,  9.0734e-02,  2.7502e+00,  ...,  0.0000e+00,
           6.4702e-01,  7.4993e-01],
         [ 2.1011e-01,  1.2185e+00,  1.5215e+00,  ...,  6.4702e-01,
           0.0000e+00,  1.4260e-01],
         [ 1.2949e-02,  1.2821e+00,  7.9318e-01,  ...,  7.4993e-01,
           1.4260e-01,  3.7253e-09]],

        [[-1.1921e-07,  6.5509e-01,  5.0194e-01,  ...,  1.9291e-01,
           1.1051e+00,  5.2810e-01],
         [ 6.5509e-01,  0.0000e+00,  1.4901e-01,  ...,  1.8733e-01,
           1.7176e-01,  1.5647e-01],
         [ 5.0194e-01,  1.4901e-01,  0.0000e+00,  ...,  1.9878e-01,
           1.3537e-01,  3.1886e-01],
         ...,
         [ 1.9291e-01,  1.8733e-01,  1.9878e-01,  ..., -2.9802e-08,
           4.8584e-01,  3.0713e-01],
         [ 1.1051e+00,  1.7176e-01,  1.3537e-01,  ...,  4.8584e-01,
           2.9802e-08,  5.2555e-01],
         [ 5.2810e-01,  1.5647e-01,  3.1886e-01,  ...,  3.0713e-01,
           5.2555e-01,  0.0000e+00]],

        [[-1.1921e-07,  1.5104e+00,  3.3076e+00,  ...,  2.0143e-01,
           1.0395e-01,  2.4303e+00],
         [ 1.5104e+00,  0.0000e+00,  8.1933e-01,  ...,  7.9454e-01,
           1.2310e+00,  7.8964e-01],
         [ 3.3076e+00,  8.1933e-01,  0.0000e+00,  ...,  1.9282e+00,
           2.4735e+00,  4.4241e-01],
         ...,
         [ 2.0143e-01,  7.9454e-01,  1.9282e+00,  ..., -5.9605e-08,
           9.0249e-02,  1.4813e+00],
         [ 1.0395e-01,  1.2310e+00,  2.4735e+00,  ...,  9.0249e-02,
           0.0000e+00,  1.6600e+00],
         [ 2.4303e+00,  7.8964e-01,  4.4241e-01,  ...,  1.4813e+00,
           1.6600e+00, -5.9605e-08]]], device='cuda:0')
radius=0.3 
temp2=sqrdists>radius ** 2
temp2=tensor([[[False,  True,  True,  ...,  True,  True,  True],
         [ True, False,  True,  ...,  True,  True,  True],
         [ True,  True, False,  ...,  True, False,  True],
         ...,
         [ True,  True,  True,  ..., False,  True, False],
         [ True,  True, False,  ...,  True, False,  True],
         [ True,  True,  True,  ..., False,  True, False]],

        [[False,  True,  True,  ...,  True,  True, False],
         [ True, False,  True,  ...,  True,  True,  True],
         [ True,  True, False,  ...,  True,  True,  True],
         ...,
         [ True,  True,  True,  ..., False,  True,  True],
         [ True,  True,  True,  ...,  True, False,  True],
         [False,  True,  True,  ...,  True,  True, False]],

        [[False,  True,  True,  ...,  True,  True,  True],
         [ True, False,  True,  ...,  True,  True,  True],
         [ True,  True, False,  ...,  True,  True,  True],
         ...,
         [ True,  True,  True,  ..., False,  True,  True],
         [ True,  True,  True,  ...,  True, False,  True],
         [ True,  True,  True,  ...,  True,  True, False]],

        [[False,  True,  True,  ...,  True,  True,  True],
         [ True, False,  True,  ...,  True,  True,  True],
         [ True,  True, False,  ...,  True,  True,  True],
         ...,
         [ True,  True,  True,  ..., False,  True,  True],
         [ True,  True,  True,  ...,  True, False,  True],
         [ True,  True,  True,  ...,  True,  True, False]]], device='cuda:0')

4. 常数和tensor比较大小后作为tensor索引

temp3=group_idx[sqrdists > radius ** 2]
temp3=tensor([   1,    2,    3,  ..., 1020, 1021, 1022], device='cuda:0')

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值