torch.gether()用法
代码展示
b = torch.tensor([[1, 2, 3], [4, 5, 6]])
print(b)
index_1 = torch.tensor([[0, 0, 0], [1, 1, 1]]).repeat(5, 1)
index_2 = torch.tensor([[0, 1, 1, 2, 2, 2, 1], [1, 1, 1, 0, 0, 2, 2]])
print(torch.gather(b, dim=1, index=index_2))
print(torch.gather(b, dim=0, index=index_1))
结果
tensor([[1, 2, 3],
[4, 5, 6]])
tensor([[1, 2, 2, 3, 3, 3, 2],
[5, 5, 5, 4, 4, 6, 6]])
tensor([[1, 2, 3],
[4, 5, 6],
[1, 2, 3],
[4, 5, 6],
[1, 2, 3],
[4, 5, 6],
[1, 2, 3],
[4, 5, 6],
[1, 2, 3],
[4, 5, 6]])
分析用法:
torch.gether()有三个参数, 第一个为src即要被索引的张量; 第二个dim,意思在哪个维度上去索引;第三个index,索引张量。
如果想成功的用gather():想在哪个维度上索引,哪个维度就可以灵活变化,可以在1-n上去边,如上面例子
torch.gather(b, dim=1, index=index_2),dim=1,所以列数可以灵活变化,原张量的形状是(2, 3),索引形状可以是(2,n),其中n>=1,形状中对应元素的取值必须小于原张量列的最大值,如上述索引的值最大为2。
torch.gather(b, dim=0, index=index_1)的原理亦是如此,索引的最大值为1
假如一个张量的形状为src.shape(a, b, c):
当dim=0, index.shape(n,b,c), 其中n>=1即可
当dim=1, index.shape(a, n ,c),其中n>=1即可
当dim=2, index.shape(a,b,n) ,其中n>=1即可