定义:从原tensor中获取指定dim和指定index的数据,生成新的tensor
输入
import torch
tensor_0 = torch.arange(3, 12).view(3, 3)
tensor_0
tensor([[ 3, 4, 5],
[ 6, 7, 8],
[ 9, 10, 11]])
Example-1
dim: 0
index_1 = torch.tensor([[2]])
tensor_1 = tensor_0.gather(0, index_1)
print("tensor_1", tensor_1)
index_2 = torch.tensor([[2, 1]])
tensor_2 = tensor_0.gather(0, index_2)
print("tensor_2", tensor_2)
index_3 = torch.tensor([[2, 1, 0]])
tensor_3 = tensor_0.gather(0, index_3)
print("tensor_3", tensor_3)
tensor_1 tensor([[9]]) # 对应下标[(2,0)]
tensor_2 tensor([[9, 7]]) # 对应下标[(2,0), (1,1)]
tensor_3 tensor([[9, 7, 5]]) # 对应下标[(2,0), (1,1), (0,2)]
理解
d
i
m
=
0
,
i
n
d
e
x
=
t
o
r
c
h
.
t
e
n
s
o
r
(
[
[
2
,
1
,
0
]
]
)
dim=0,index=torch.tensor([[2, 1, 0]])
dim=0,index=torch.tensor([[2,1,0]]):表示将取出下标
[
(
0
,
0
)
,
(
0
,
1
)
,
(
0
,
2
)
]
[(0,0), (0,1), (0,2)]
[(0,0),(0,1),(0,2)]的
d
i
m
=
1
dim=1
dim=1维度不变,
d
i
m
=
0
dim=0
dim=0维度根据
i
n
d
e
x
index
index修改为
[
(
2
,
0
)
,
(
1
,
1
)
,
(
0
,
2
)
]
[(2,0), (1,1), (0,2)]
[(2,0),(1,1),(0,2)]
Example-2
dim: 0
index_4 = torch.tensor([[2]])
tensor_4 = tensor_0.gather(1, index_4)
print("tensor_4", tensor_4)
index_5 = torch.tensor([[2, 1]])
tensor_5 = tensor_0.gather(1, index_5)
print("tensor_5", tensor_5)
index_6 = torch.tensor([[2, 1, 0]])
tensor_6 = tensor_0.gather(1, index_6)
print("tensor_6", tensor_6)
tensor_4 tensor([[5]]) # 对应下标[(0,2)]
tensor_5 tensor([[5, 4]]) # 对应下标[(0,2), (0,1)]
tensor_6 tensor([[5, 4, 3]]) # 对应下标[(0,2), (0,1), (0, 0)]
理解
d
i
m
=
1
,
i
n
d
e
x
=
t
o
r
c
h
.
t
e
n
s
o
r
(
[
[
2
,
1
,
0
]
]
)
dim=1,index=torch.tensor([[2, 1, 0]])
dim=1,index=torch.tensor([[2,1,0]]):表示将取出下标
[
(
0
,
0
)
,
(
0
,
1
)
,
(
0
,
2
)
]
[(0,0), (0,1), (0,2)]
[(0,0),(0,1),(0,2)]的
d
i
m
=
0
dim=0
dim=0维度不变,
d
i
m
=
1
dim=1
dim=1维度根据
i
n
d
e
x
index
index修改为
[
(
0
,
2
)
,
(
0
,
1
)
,
(
0
,
0
)
]
[(0,2), (0,1), (0,0)]
[(0,2),(0,1),(0,0)]
Example-3
dim: 0
index_7 = torch.tensor([[0, 2], [1, 2], [0, 2]])
tensor_7 = tensor_0.gather(1, index_7)
tensor_7
tensor([[ 3, 5],
[ 7, 8],
[ 9, 11]])
理解
d
i
m
=
1
,
i
n
d
e
x
=
t
o
r
c
h
.
t
e
n
s
o
r
(
[
[
0
,
2
]
,
[
1
,
2
]
,
[
0
,
2
]
]
)
dim=1,index=torch.tensor([[0, 2], [1, 2], [0, 2]])
dim=1,index=torch.tensor([[0,2],[1,2],[0,2]]):表示将取出下标
[
[
(
0
,
0
)
,
(
0
,
1
)
]
,
[
(
1
,
0
)
,
(
1
,
1
)
]
,
[
(
2
,
0
)
,
(
2
,
1
)
]
]
[[(0,0), (0,1)], [(1, 0) , (1, 1)], [(2, 0), (2, 1)]]
[[(0,0),(0,1)],[(1,0),(1,1)],[(2,0),(2,1)]]的
d
i
m
=
0
dim=0
dim=0维度不变,
d
i
m
=
1
dim=1
dim=1维度根据
i
n
d
e
x
index
index修改为
[
[
(
0
,
0
)
,
(
0
,
2
)
]
,
[
(
1
,
1
)
,
(
1
,
2
)
]
,
[
(
2
,
0
)
,
(
2
,
2
)
]
]
[[(0,0), (0,2)], [(1, 1) , (1, 2)], [(2, 0), (2, 2)]]
[[(0,0),(0,2)],[(1,1),(1,2)],[(2,0),(2,2)]]
- tensor
tensor([[ 3, 4, 5],
[ 6, 7, 8],
[ 9, 10, 11]])
- 取值结果
修改前下标 | index | 修改后下标 | 对应的数值 |
---|---|---|---|
[(0,0), (0,1)] | [0, 2] | [(0, 0), (0, 2)] | [ 3, 5] |
[(1, 0) , (1, 1)] | [1, 2] | [(1, 1) , (1, 2)] | [ 7, 8] |
[(2, 0), (2, 1)] | [0, 2] | [(2, 0), (2, 2)] | [ 9, 11] |