看了好几篇了,没有直接看明白,特梳理之
功能
数据收集,函数torch.gather(input, dim, index, out=None) → Tensor
沿给定轴 dim ,将输入索引张量 index 指定位置的值进行聚合.
理解
对于一个shape为(3,4)的数据a,可用索引a[1,2]取 a[1][2]数据;即通过a[i,j]的方式可以获取数据,gather则通过类似方式收集数据。
通过示例理解gather具体的方式,gather则内容更丰富。
结论
Gather获取的数据shape和idx索引的shape一样,获取数据的内容为data中idx对应的坐标,idx对应的坐标按指定轴被idx坐标处值替换后的坐标。
idx数组的shape不一定小于data的,只要对应的坐标构造后合法即可。
看后面示例
1维数据
dim 只能取 0;
>>> a=torch.arange(1,12,2)
>>> a
tensor([ 1, 3, 5, 7, 9, 11])
>>> idx = torch.tensor([1,3,5])
>>> a.gather(0,idx)
tensor([ 3, 7, 11])
取第dim=0,的第1,3,5个数据,可以读取任意位置 任意数量的元素。
2维数据
dim能取0,1;
0轴取数据示例,通过构造的坐标理解gather 的原理
1轴取数据示例
3维数据
3维数据0轴
3维数据1轴
3维数据2轴
应用: 我从RL中看到这个函数,其他的应用大家补充⑧