目标需求
我有一个维度是 [512, 20, 128] 的tensor,我希望从512的batch里,每次按照[20] 这一维给定索引,得到一个128维的向量,然后遍历整个batch,最终得到 [512, 128] 的tensor。
如果用循环来实现就很简单:
data = torch.rand([512, 20, 128]) # (batch, idx, embd)
index = torch.randint(20, (512, )) # (batch,)
for i in range(512):
tmp = data[i].index_select(0, index[i])
try:
out = torch.cat((out, tmp), dim=0)
except:
out = tmp
print(out.shape) # [512, 128]
如果不用循环来实现呢?
import torch
data = torch.rand([512, 20, 128]) # (batch, idx, embd)
index = torch.randint(20, (512, )) # (batch,)
index_new = index[..., None, None].expand(-1, -1, data.shape[2])
out = torch.gather(data, 1, index_new).squeeze()
print(out.shape)
解读代码,index_select很方便,但只能取出一维。题目是先要按512维的顺序取出[20, 128],然后根据第一个给定索引取出一个128维向量,接着取出第二个[20, 128],然后根据第二个索引取出一个128维向量,最终得到一个[512, 128]维tensor。
借助gather的话,需要保证data和index的维度一致,因此我们需要对给定的index进行扩充,同时还要对第一维以外的维度进行复制扩充。