fluid.layers.gather(input, index, overwrite=True)
- 作用:根据索引 index 获取输入(input)的最外层维度的条目,并将它们拼接在一起。
- 链接:pp飞桨API说明
- 对比:通常来说此api用法和pytorch没什么特殊,但因为API缺少了dim参数,使得当多维tensor索引时报错
示例
输入维度
- X.shape = [b, c, 81204]
- Index.shape = [b, c, 720000]
输出维度 - X_offset = [b, c, 720000]
# pytorch code
x_offset = x.gather(dim=-1, index=index)
# paddlepaddle code
x_offset = fluid.layers.concat([
fluid.layers.reshape(
fluid.layers.concat([
fluid.layers.reshape(fluid.layers.gather(x[i, j, :], index[i, j, :]), [1, -1])
for j in range(c)], axis=0)
, [1, c, -1]) for i in range(b)], axis=0)
总结:由于padddle中,gather的用法缺少dim,所以要曲线救国,采用for、reshape和concat相结合的方法处理。