Decomposing the torch.gather operation

import torch
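
For a 3-D input and dim=1, torch.gather(input, 1, index) follows the rule out[i][j][k] = input[i][ index[i][j][k] ][k]: the index tensor only replaces the coordinate along the gathered dimension. A toy illustration of the rule first (the tensor t here is illustrative, not part of the original example):

t = torch.tensor([[1, 2], [3, 4]])
# dim=1: out[i][j] = t[i][ index[i][j] ]
print(torch.gather(t, 1, torch.tensor([[1, 0], [0, 0]])))  # tensor([[2, 1], [3, 3]])

The walkthrough below applies the same rule to strip the [CLS] (101) and [SEP] (102) positions out of a padded batch.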
token = [  # each row: [CLS]=101 ... [SEP]=102, right-padded with 0
    [101, 2, 3, 4, 5, 102, 0, 0],
    [101, 1, 4, 3, 102, 0, 0, 0],
    [101, 4, 2, 2, 5, 7, 102, 0],
    [101, 1, 3, 4, 102, 0, 0, 0]
]


emb = torch.tensor([[[0.1, 0.1, 0.1, 0.1, 0.1],
         [0.1238, 0.3628, 0.8905, 0.0051, 0.3695],
         [0.7803, 0.3018, 0.7711, 0.8743, 0.8260],
         [0.8933, 0.0041, 0.0405, 0.4883, 0.6772],
         [0.7019, 0.5739, 0.7999, 0.2417, 0.0757],
         [0.2, 0.2, 0.2, 0.2, 0.2],
         [0.0, 0.0, 0.0, 0.0, 0.0],
         [0.0, 0.0, 0.0, 0.0, 0.0]],

        [[0.1, 0.1, 0.1, 0.1, 0.1],
         [0.4643, 0.3200, 0.3301, 0.6153, 0.0311],
         [0.1726, 0.1853, 0.6196, 0.6015, 0.5596],
         [0.0537, 0.0339, 0.4433, 0.6358, 0.9984],
         [0.2, 0.2, 0.2, 0.2, 0.2],
         [0.0, 0.0, 0.0, 0.0, 0.0],
         [0.0, 0.0, 0.0, 0.0, 0.0],
         [0.0, 0.0, 0.0, 0.0, 0.0]],

        [[0.1, 0.1, 0.1, 0.1, 0.1],
         [0.9288, 0.7993, 0.8453, 0.5213, 0.7907],
         [0.5570, 0.0945, 0.9134, 0.8189, 0.1070],
         [0.3729, 0.0044, 0.2153, 0.0338, 0.3157],
         [0.6377, 0.3992, 0.9940, 0.6390, 0.1803],
         [0.2793, 0.8569, 0.2933, 0.6693, 0.7314],
         [0.2, 0.2, 0.2, 0.2, 0.2],
         [0.0, 0.0, 0.0, 0.0, 0.0]],

        [[0.1, 0.1, 0.1, 0.1, 0.1],
         [0.7950, 0.5519, 0.2370, 0.9076, 0.2815],
         [0.8589, 0.4967, 0.0160, 0.4967, 0.0296],
         [0.4961, 0.0019, 0.9717, 0.5534, 0.2273],
         [0.2, 0.2, 0.2, 0.2, 0.2],
         [0.0, 0.0, 0.0, 0.0, 0.0],
         [0.0, 0.0, 0.0, 0.0, 0.0],
         [0.0, 0.0, 0.0, 0.0, 0.0]]])

# Convert emb's original indices into the indices of the positions to keep (drop [CLS] and [SEP])

x_len = [6, 5, 7, 5]          # real (non-padding) length of each row of token
orgi_total_batch = [4, 8, 5]  # emb shape: (batch, seq_len, hidden)
out_idx = []
idx_select = []
for batch_size in range(orgi_total_batch[0]):
    b_ = []
    b2_ = []
    for seq in range(orgi_total_batch[1]):
        if seq != 0 and seq != x_len[batch_size] - 1:  # skip the [CLS] and [SEP] positions
            b1_ = []
            for _dim in range(orgi_total_batch[2]):
                b_.append((seq, (batch_size, seq, _dim)))
                # out_idx.append((seq, (batch_size, seq, _dim)))
                b1_.append(seq)
            b2_.append(b1_)

    idx_select.append(b2_)
    out_idx.append(b_)


idx_select = torch.LongTensor(idx_select)  # shape: (4, 6, 5)
print(idx_select)
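
The same index tensor can be built without the triple loop; a minimal vectorized sketch that reproduces idx_select, assuming the x_len and shapes above:

keep = [[s for s in range(8) if s != 0 and s != L - 1] for L in x_len]
idx_alt = torch.tensor(keep).unsqueeze(-1).expand(-1, -1, 5)  # (4, 6, 5)
assert torch.equal(idx_alt, idx_select)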

total_batch = [4, 6, 5]  # shape after dropping [CLS] and [SEP]: (batch, seq_len - 2, hidden)
idx_list = []
for batch_size in range(total_batch[0]):
    for seq in range(total_batch[1]):
        for _dim in range(total_batch[2]):
            idx_list.append((idx_select[batch_size][seq][_dim], (batch_size, seq, _dim),
                             (batch_size, idx_select[batch_size][seq][_dim].tolist(), _dim)))

# Each idx_list entry records: the value idx_select[batch_size][seq][_dim] stored at position
# (batch_size, seq, _dim), and (batch_size, idx_select[batch_size][seq][_dim], _dim),
# which is the position in emb that gather will actually read.

out = torch.gather(emb, 1, idx_select)  # out[b][s][d] = emb[b][ idx_select[b][s][d] ][d]
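
The gathered tensor keeps the index tensor's extent along dim 1, so a quick shape check:

print(out.shape)  # torch.Size([4, 6, 5])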
print(idx_list)
# Enumerate every element of emb together with its position, to cross-check against idx_list
orgi_total_batch = [4, 8, 5]
out_ = []
for batch_size in range(orgi_total_batch[0]):
    for seq in range(orgi_total_batch[1]):
        for _dim in range(orgi_total_batch[2]):
            out_.append((emb[batch_size][seq][_dim], (batch_size, seq, _dim)))
print(out_)



'''
Row 0 of out[0] gathers from input at batch=[0,0,0,0,0], seq=[1,1,1,1,1], dim=[0,1,2,3,4],
because index[0][0][0,1,2,3,4] = [1,1,1,1,1].
Row 1 works the same way with index[0][1][0,1,2,3,4] = [2,2,2,2,2]:
out[0][1][0] = input[0][ index[0][1][0] ][0] = input[0][2][0] = 0.7803
out[0][1][1] = input[0][ index[0][1][1] ][1] = input[0][2][1] = 0.3018
out[0][1][2] = input[0][ index[0][1][2] ][2] = input[0][2][2] = 0.7711
out[0][1][3] = input[0][ index[0][1][3] ][3] = input[0][2][3] = 0.8743
out[0][1][4] = input[0][ index[0][1][4] ][4] = input[0][2][4] = 0.8260
'''
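
The rule can be verified end to end by rebuilding out with explicit loops; a minimal sanity check using only the tensors defined above:

manual = torch.empty_like(out)
for b in range(out.size(0)):
    for s in range(out.size(1)):
        for d in range(out.size(2)):
            # gather along dim=1: out[b][s][d] = emb[b][ idx_select[b][s][d] ][d]
            manual[b][s][d] = emb[b][idx_select[b][s][d]][d]
assert torch.equal(manual, out)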

 
