import torch
# Toy batch of 4 token-id sequences, right-padded with 0 to length 8.
# 101 / 102 look like BERT-style [CLS] / [SEP] special tokens -- TODO confirm
# against the tokenizer that produced these ids.
token = [
[101, 2, 3, 4, 5, 102, 0, 0],
[101, 1, 4, 3, 102, 0, 0, 0],
[101, 4, 2, 2, 5, 7, 102, 0],
[101, 1, 3, 4, 102, 0, 0, 0]
]
# Hand-written embeddings matching `token`, shape (4, 8, 5):
#   rows of 0.1 sit at position 0 (where token id is 101),
#   rows of 0.2 sit at the last real position (where token id is 102),
#   rows of 0.0 sit at padding positions (token id 0).
emb = torch.tensor([[[0.1, 0.1, 0.1, 0.1, 0.1],
[0.1238, 0.3628, 0.8905, 0.0051, 0.3695],
[0.7803, 0.3018, 0.7711, 0.8743, 0.8260],
[0.8933, 0.0041, 0.0405, 0.4883, 0.6772],
[0.7019, 0.5739, 0.7999, 0.2417, 0.0757],
[0.2, 0.2, 0.2, 0.2, 0.2],
[0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0]],
[[0.1, 0.1, 0.1, 0.1, 0.1],
[0.4643, 0.3200, 0.3301, 0.6153, 0.0311],
[0.1726, 0.1853, 0.6196, 0.6015, 0.5596],
[0.0537, 0.0339, 0.4433, 0.6358, 0.9984],
[0.2, 0.2, 0.2, 0.2, 0.2],
[0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0]],
[[0.1, 0.1, 0.1, 0.1, 0.1],
[0.9288, 0.7993, 0.8453, 0.5213, 0.7907],
[0.5570, 0.0945, 0.9134, 0.8189, 0.1070],
[0.3729, 0.0044, 0.2153, 0.0338, 0.3157],
[0.6377, 0.3992, 0.9940, 0.6390, 0.1803],
[0.2793, 0.8569, 0.2933, 0.6693, 0.7314],
[0.2, 0.2, 0.2, 0.2, 0.2],
[0.0, 0.0, 0.0, 0.0, 0.0]],
[[0.1, 0.1, 0.1, 0.1, 0.1],
[0.7950, 0.5519, 0.2370, 0.9076, 0.2815],
[0.8589, 0.4967, 0.0160, 0.4967, 0.0296],
[0.4961, 0.0019, 0.9717, 0.5534, 0.2273],
[0.2, 0.2, 0.2, 0.2, 0.2],
[0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0]]])
# Map emb's original sequence indices to the indices that should be kept:
# drop position 0 (the 0.1 row, [CLS]-like) and position x_len-1 (the 0.2
# row, [SEP]-like); padding rows past x_len are still kept.
x_len = [6, 5, 7, 5]          # true (unpadded) length of each sequence
orgi_total_batch = [4, 8, 5]  # original emb shape: (batch, seq, dim)
out_idx = []     # debug bookkeeping: (kept_seq, (batch, seq, dim)) per element
idx_select = []  # per batch: one [seq]*dim row for each kept position
for batch in range(orgi_total_batch[0]):
    b_ = []
    b2_ = []
    for seq in range(orgi_total_batch[1]):
        # keep everything except position 0 and the last real position.
        # (Original compared via str(); plain int comparison is equivalent.)
        if seq != 0 and seq != x_len[batch] - 1:
            b1_ = []
            for _dim in range(orgi_total_batch[2]):
                b_.append((seq, (batch, seq, _dim)))
                b1_.append(seq)
            b2_.append(b1_)
    idx_select.append(b2_)
    out_idx.append(b_)
# gather needs an integer (long) index tensor; shape here is (4, 6, 5)
idx_select = torch.LongTensor(idx_select)
print(idx_select)  # BUG FIX: original printed undefined name `idx`
# Shape of the gathered result: (batch, kept_seq, dim).
total_batch = [4, 6, 5]
# For every output position (b, s, d) record three things:
#   the index value idx_select[b][s][d] (a 0-d tensor),
#   the output position (b, s, d) itself,
#   the source position (b, idx_select[b][s][d], d) that gather reads from.
idx_list = [
    (idx_select[b][s][d], (b, s, d), (b, idx_select[b][s][d].tolist(), d))
    for b in range(total_batch[0])
    for s in range(total_batch[1])
    for d in range(total_batch[2])
]
# gather along dim=1: out[b][s][d] = emb[b][ idx_select[b][s][d] ][d]
out = torch.gather(emb, 1, idx_select)
print(idx_list)
# Flatten emb into (value, (batch, seq, dim)) pairs so the gathered `out`
# can be eyeballed against the original tensor element by element.
orgi_total_batch = [4, 8, 5]  # original emb shape
out_ = [
    (emb[b][s][d], (b, s, d))
    for b in range(orgi_total_batch[0])
    for s in range(orgi_total_batch[1])
    for d in range(orgi_total_batch[2])
]
print(out_)
'''
Worked example of gather along dim=1: out[b][s][d] = input[b][ index[b][s][d] ][d].

Row 0 of out[0] uses index[0][0] = [1,1,1,1,1], i.e. it reads input at
batch=[0,0,0,0,0], seq=[1,1,1,1,1], dim=[0,1,2,3,4].

Row 1 of out[0] uses index[0][1] = [2,2,2,2,2]:
out[0][1][0] = input[0][ index[0][1][0] ][0] = input[0][2][0] = 0.7803
out[0][1][1] = input[0][ index[0][1][1] ][1] = input[0][2][1] = 0.3018
out[0][1][2] = input[0][ index[0][1][2] ][2] = input[0][2][2] = 0.7711
out[0][1][3] = input[0][ index[0][1][3] ][3] = input[0][2][3] = 0.8743
out[0][1][4] = input[0][ index[0][1][4] ][4] = input[0][2][4] = 0.8260
'''