使用 PyTorch 实现 tf.batch_gather：
def batch_gather(data: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    """Gather slices of ``data`` per batch row, like ``tf.batch_gather``.

    For every batch position ``i`` (``0 <= i < index.shape[0]``) this selects
    ``data[i, index[i]]`` along dimension 1. It is the vectorized form of::

        torch.stack([data[i, index[i]] for i in range(index.shape[0])])

    Args:
        data: Tensor with at least two dimensions. Dim 0 is the batch
            dimension and dim 1 is the dimension being indexed.
        index: Integer tensor whose dim 0 is the batch dimension; its length
            must not exceed ``data.shape[0]``. The remaining dims hold the
            positions to take from ``data``'s dim 1 for that batch row.

    Returns:
        Tensor of shape ``index.shape + data.shape[2:]`` with ``data``'s dtype,
        on ``data``'s device. The result stays in the autograd graph, so
        gradients flow back into ``data``. A NumPy round-trip would not allow
        that.

    Raises:
        IndexError: (CPU tensors) if an index is out of range for ``data``'s
            dim 1, or if ``index.shape[0] > data.shape[0]``.
    """
    batch = index.shape[0]
    # Advanced indexing needs integer indices on the same device as ``data``.
    index = index.to(device=data.device, dtype=torch.long)
    # Batch-row ids shaped (batch, 1, ..., 1) so they broadcast against
    # ``index``: result[i, j, ...] == data[i, index[i, j, ...]].
    rows = torch.arange(batch, device=data.device).reshape(
        (batch,) + (1,) * (index.dim() - 1)
    )
    return data[rows, index]