torch.gather()
torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor
作用:沿着由dim指定的轴收集数值。
- 参数
input (Tensor) – 目标变量,输入
dim (int) – 需要沿着取值的坐标轴
index (LongTensor) – 需要取值的索引矩阵
sparse_grad (bool,optional) – 如果为真,输入将是一个稀疏张量
out (Tensor, optional) – 输出
- 简单例子
>>> t = torch.tensor([[1,2],[3,4]])
>>> torch.gather(t, 1, torch.tensor([[0,0],[1,0]]))
tensor([[ 1, 1],
[ 4, 3]])
解释与例子
torch.gather迷惑大家的点主要是在dim的问题上,大家不知道怎么个取值法。 它的官方文档不太好理解,举得例子也太过于简单,下面我将举个例子详细说明gather的原理。
- dim=0的情况
input = [
[0.0, 0.1, 0.2, 0.3],
[1.0, 1.1, 1.2, 1.3],
[2.0, 2.1, 2.2, 2.3]
]#shape [3,4]
input = torch.tensor(input)
length = torch.LongTensor([
[2,2,2,2],
[1,1,1,1],
[0,0,0,0],
[0,1,2,0]
])#[4,4]
out = torch.gather(input, dim=0, index=length)
print(out)
结果
tensor([[2.0000, 2.1000, 2.2000, 2.3000],
[1.0000, 1.1000, 1.2000, 1.3000],
[0.0000, 0.1000, 0.2000, 0.3000],
[0.0000, 1.1000, 2.2000, 0.3000]])
这里直接给到大家对应的取值矩阵。
gather函数就是通过length矩阵也就是我上述的矩阵来进行取值的。
例
如
X
20
就
代
表
了
i
n
p
u
t
的
第
2
行
第
0
列
的
2.0
例如X_{20}就代表了input的第2行第0列的2.0
例如X20就代表了input的第2行第0列的2.0
因为dim=0代表的是横向,按行取值。
即length矩阵中的数的值代表的是行数,数的位置代表的列数,比如length矩阵中的第三行第三列(从0数起)的数 0,其值是0,代表在input中所取的数是第0行,位置是第三列,则表示在input中所取的数是第三列
即
X
03
=
0.3
X_{03} =0.3
X03=0.3
- dim=1
input = [
[0.0, 0.1, 0.2, 0.3],
[1.0, 1.1, 1.2, 1.3],
[2.0, 2.1, 2.2, 2.3]
]#shape [3,4]
input = torch.tensor(input)
length = torch.LongTensor([
[2,2,2,2],
[1,1,1,1],
[0,1,2,0]
])#[3,4]
out = torch.gather(input, dim=1, index=length)
print(out)
结果
tensor([[0.2000, 0.2000, 0.2000, 0.2000],
[1.1000, 1.1000, 1.1000, 1.1000],
[2.0000, 2.1000, 2.2000, 2.0000]])
同理与dim=0时相反,数值代表列数,位置代表行数。所以对应的取值矩阵是
[
X
02
X
02
X
02
X
02
X
11
X
11
X
11
X
11
X
20
X
21
X
22
X
20
]
\begin{bmatrix} X_{02}& X_{02}& X_{02}& X_{02}\\ X_{11}& X_{11}& X_{11}& X_{11}\\ X_{20}& X_{21}& X_{22}& X_{20} \end{bmatrix}
⎣⎡X02X11X20X02X11X21X02X11X22X02X11X20⎦⎤
总结一下
-
dim=0
length矩阵中的数的值代表的是行数,所处位置代表的是列数。
所以此时length矩阵里面的数的值应当小于input.shape[0]=3(不能等于因为是从0计数),length矩阵的的最大列数应当小于等于input.shape[1]=4(此处可以等于因为是列数) -
dim=1
length矩阵中的数的值代表的是列数,所处位置代表的是行数。
所以此时length矩阵里面的数的值应当小于input.shape[1]=4(不能等于因为是从0计数),length矩阵的的最大行数应当小于等于input.shape[0]=3(此处可以等于因为是行数)
如有疑问欢迎留言询问哦。