索引操作函数gather函数详解
事先声明:本文只会对二维张量的gather操作进行介绍,三维张量的gather操作规则在csdn上的博文屡见不鲜。本文的解释是从个人的理解出发,相信解释也会对理解三维张量的操作规则起到触类旁通的作用。
gather函数的输出规则
o
u
t
[
i
]
[
j
]
=
i
n
p
u
t
[
i
n
d
e
x
[
i
]
[
j
]
]
[
j
]
,
i
f
d
i
m
=
=
0
out [i] [j] = input [index [i] [j] ] [j], if{\ }dim == 0
out[i][j]=input[index[i][j]][j],if dim==0
o
u
t
[
i
]
[
j
]
=
i
n
p
u
t
[
i
]
[
i
n
d
e
x
[
i
]
[
j
]
]
,
i
f
d
i
m
=
=
1
out [i] [j] = input [i] [index [i] [j] ], if{\ }dim == 1
out[i][j]=input[i][index[i][j]],if dim==1
第一条规则
从行的角度出发,输入的index张量按照如上规则,取出对应的输入张量的元素。
例如:一个
2
∗
3
2*3
2∗3的维度的张量,index为
[
0
,
1
,
1
]
[0, 1, 1]
[0,1,1],取
d
i
m
=
0
dim=0
dim=0,根据规则,外层循环为变量
i
i
i,内层循环为变量
j
j
j,且
i
i
n
r
a
n
g
e
(
0
,
2
)
;
j
i
n
r
a
n
g
e
(
0
,
3
)
i {\ }in{\ } range(0, 2); j{\ } in{\ } range(0, 3)
i in range(0,2);j in range(0,3)。
代入
i
=
0
,
j
=
1
i=0,{\ }j=1
i=0, j=1,得到:
o
u
t
[
0
]
[
1
]
=
i
n
p
u
t
[
i
n
d
e
x
[
0
]
[
1
]
]
[
1
]
out[0][1]=input[index[0][1]][1]
out[0][1]=input[index[0][1]][1]
o
u
t
[
0
]
[
1
]
=
i
n
p
u
t
[
1
]
[
1
]
out[0][1]=input[1][1]
out[0][1]=input[1][1]
即:该输出元素为输入的
2
∗
3
2*3
2∗3维度张量的第1行第1列元素。且该元素在输出张量中处在第0行第1列的位置。
如下表所示:
0 | 1 | 2 | |
---|---|---|---|
0 | |||
1 | this element |
其中, 0 , 1 , 2 {0, 1, 2} 0,1,2代表列标号, 0 , 1 {0, 1} 0,1代表行标号。
第二条规则
从列的角度出发,输入的index张量按照如上规则,取出对应的输入张量的元素。
例如:一个
2
∗
3
2*3
2∗3的维度的张量,index为
[
[
0
,
1
,
1
]
,
[
1
,
1
,
1
]
]
[[0, 1, 1],[1, 1, 1]]
[[0,1,1],[1,1,1]],取
d
i
m
=
1
dim=1
dim=1,根据规则,外层循环为变量
i
i
i,内层循环为变量
j
j
j,且
i
i
n
r
a
n
g
e
(
0
,
2
)
;
j
i
n
r
a
n
g
e
(
0
,
3
)
i {\ }in{\ } range(0, 2); j{\ } in{\ } range(0, 3)
i in range(0,2);j in range(0,3)。
代入
i
=
1
,
j
=
1
i=1,{\ }j=1
i=1, j=1,得到:
o
u
t
[
1
]
[
1
]
=
i
n
p
u
t
[
1
]
[
i
n
d
e
x
[
1
]
[
1
]
]
out[1][1]=input[1][index[1][1]]
out[1][1]=input[1][index[1][1]]
o
u
t
[
1
]
[
1
]
=
i
n
p
u
t
[
1
]
[
1
]
out[1][1]=input[1][1]
out[1][1]=input[1][1]
即:该输出元素为输入的
2
∗
3
2*3
2∗3维度张量的第1行第1列元素,且该元素在输出张量中处在第1行第1列的位置。
如下表所示:
0 | 1 | 2 | |
---|---|---|---|
0 | |||
1 | this element |
其中, 0 , 1 , 2 {0, 1, 2} 0,1,2代表列标号, 0 , 1 {0, 1} 0,1代表行标号。
gather函数内部的代码机理推测
声明:下述代码仅针对原理部分编写,距离函数内部真实情况仍存在较大差距,且下述代码的严谨性不够,故仅供理解gather的核心规则。
def gather(input, dim, index):
# 这里的dim要求取0或1
out = []
m = input.size()[0] # size函数是torch的方法
n = input.size()[1]
for i in range(m):
for j in range(n):
if dim == 0:
out [i] [j] = input [index [i] [j] ] [j]
if dim == 1:
out [i] [j] = input [i] [index [i] [j] ]
return out
代码示例
与上一篇博文内容相同,这里再次展示一遍。
import torch
# 设置一个随机种子
torch.manual_seed(100)
# 生成一个形状为2*3的矩阵
x = torch.randn(2, 3)
print(x)
# 获取指定索引对应的值
index = torch.LongTensor([[0, 1, 1]])
print(torch.gather(x, 0, index))
index = torch.LongTensor([[0, 1, 1], [1, 1, 1]])
a = torch.gather(x, 1, index)
print(a)
输出结果
参考文献
吴茂贵,郁明敏,杨本法,李涛,张粤磊. Python深度学习(基于Pytorch). 北京:机械工业出版社,2019.