PyTorch | 索引与切片
1. 普通索引
import torch
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"
photo = torch.rand(4,3,28,28)
photo.shape
photo[2].shape # 索引到第一个维度
photo[3,2].shape # 索引到第二个维度
photo[1,2,3].shape # 索引到第三个维度
photo[1,2,3,4] # 索引到具体元素,返回一个标量
2. 冒号索引(切片)
import torch
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"
photo = torch.rand(4,3,28,28)
photo.shape
# 索引到第一维度0,1(不包含2)
photo[:2].shape
# 索引到第一维度0,1(不包含2);第二维度2
photo[:2, 2:].shape
# 索引到第一个维度0,1(不包含2);第二个维度1,2;第三、四维度所有
photo[:2, 1:, :, :].shape
# 索引到第一个维度0,1(不包含2);第二个维度2;第三维度从倒数第三到最后;第四维度所有
photo[:2, 2:, -3:, :].shape
# 索引到第一个维度0,1(不包含2);第二个维度2;第三维度从0到27,每两步取一个元素;第四维度所有
photo[:2, 2:, 0:28:2, :].shape
3. index_select 选择特定索引
-
torch.index_select(input, dim, index)
:沿指定维度 d i m dim dim 对输入进行切片,取 i n d e x index index 中指定的相应项,然后返回一个新的张量,返回的张量与原始张量有相同的维度(在指定轴上),返回的张量与原始张量不共享内存空间。下面 p h o t o photo photo 这个 t e n s o r tensor tensor 可以看作 7 7 7 张 R G B ( 3 通 道 ) RGB(3通道) RGB(3通道) 的 M N I S T MNIST MNIST 图像,长宽都是 28 p x 28px 28px。那么在第一维度上可以选择特定的图片,在第二维度上选择特定的通道,在第三维度上选择特定的行,在第四维度上选择特定的列等。
import torch from IPython.core.interactiveshell import InteractiveShell InteractiveShell.ast_node_interactivity = "all" photo = torch.rand(7,3,28,28) photo.shape # 选择第一张和第三张图 photo.index_select(0, torch.tensor([0, 2])).shape # 选择R通道和B通道 photo.index_select(1, torch.tensor([0, 2])).shape # 选择图像的0~8行 photo.index_select(2, torch.arange(8)).shape # 选择图像的0~8列 photo.index_select(3, torch.arange(8)).shape
4. masked_select 选择符合条件的索引
torch.masked_select(input,mask)
:根据一个布尔掩码 ( b o o l e a n m a s k ) (boolean\ mask) (boolean mask) 索引返回一个一维张量。 m a s k e d _ s e l e c t masked\_select masked_select 索引是在原来 t e n s o r tensor tensor 的 s h a p e shape shape 基础上打平,然后在打平后的 t e n s o r tensor tensor 上进行索引。返回的张量不与原始张量共享内存空间。
import torch from IPython.core.interactiveshell import InteractiveShell InteractiveShell.ast_node_interactivity = "all" r = torch.randn(3, 4) r # 生成r这个tensor中大于0.5的元素的掩码 mask = r.ge(0.5) mask # 取出r这个Tensor中大于0.5的元素 mr = torch.masked_select(r, mask) mr mr.shape
5. take索引
torch.take(input, index)
: t a k e take take 索引是在原来 t e n s o r tensor tensor 的 s h a p e shape shape 基础上打平,然后按照 i n d e x index index 在打平后的 t e n s o r tensor tensor 上索取对应位置的元素。import torch from IPython.core.interactiveshell import InteractiveShell InteractiveShell.ast_node_interactivity = "all" src = torch.randn(3, 4) src torch.take(src, torch.tensor([1,3,5,7,9,11]))
6. 使用 … 索引任意多的维度
import torch
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"
photo = torch.rand(7,3,28,28)
photo.shape
# 等于photo.shape
photo[...].shape
# 第一张图片的所有维度
photo[0,...].shape
# 所有图片第二通道的所有维度
photo[:,1,...].shape
# 所有图像所有通道第一、第二行的所有列
photo[...,:2,:].shape
# 所有图像所有通道所有行的第一、第二列
photo[...,:2].shape
7. gather 函数
torch.gather(input, dim, index, out=None)
:将输入 i n p u t input input 张量按照指定的维度 d i m dim dim 和设定的索引 i n d e x ( L o n g T e n s o r ) index(LongTensor) index(LongTensor) 取值,返回一个新的 t e n s o r tensor tensor。 i n d e x ( L o n g T e n s o r ) index(LongTensor) index(LongTensor) 的维度要与 i n p u t input input 维度相同。 d i m dim dim 为从左到右的维度,将 i n d e x ( L o n g T e n s o r ) index(LongTensor) index(LongTensor) 的索引的 d i m dim dim 维度值修改为 i n d e x index index 中的值,然后到 i n p u t input input 张量中进行索引,就生成了 g a t h e r gather gather 返回的 t e n s o r tensor tensor。
8. where 函数
torch.where(condition, x, y)
:根据 c o n d i t i o n condition condition 合并两个 t e n s o r tensor tensor 类型,按照一定的规则从 x , y x,y x,y 中选择元素组成新的张量, c o n d i t i o n condition condition 是规则, x x x 和 y y y 是同 s h a p e shape shape 的矩阵。如果满足条件,则返回 x x x 中元素。若不满足,返回 y y y 中元素。
8. tril & triu & diag 函数
-
torch.tril(input, diagonal=0, out=None)
:返回一个张量,包含输入张量( i n p u t input input)的下三角部分,其余部分设为 0 0 0,参数 d i a g o n a l diagonal diagonal 控制对角线。- 如果 d i a g o n a l diagonal diagonal 为空,输入矩阵保留 主对角线 及 主对角线下方 的所有元素;
- 如果 d i a g o n a l diagonal diagonal 为正数 n n n,输入矩阵保留主对角线上方第 n \pmb{n} nnn 条对角线 和 主对角线上方第 n \pmb{n} nnn 条对角线以下 的所有元素;
- 如果
d
i
a
g
o
n
a
l
diagonal
diagonal 为负数
−
n
-n
−n,输入矩阵保留 主对角线下方第
n
\pmb{n}
nnn 条对角线 及 主对角线下方第
n
\pmb{n}
nnn 条对角线以下 的元素;
-
torch.triu(input, diagonal=0, out=None)
:返回一个张量,包含输入张量( i n p u t input input)的上三角部分,其余部分设为 0 0 0,参数 d i a g o n a l diagonal diagonal 控制对角线。- 如果 d i a g o n a l diagonal diagonal 为空,输入矩阵保留 主对角线 及 主对角线上方 的所有元素;
- 如果 d i a g o n a l diagonal diagonal 为正数 n n n,输入矩阵保留主对角线上方第 n \pmb{n} nnn 条对角线 和 主对角线上方第 n \pmb{n} nnn 条对角线以上 的所有元素;
- 如果
d
i
a
g
o
n
a
l
diagonal
diagonal 为负数
−
n
-n
−n,输入矩阵保留 主对角线下方第
n
\pmb{n}
nnn 条对角线 及 主对角线下方第
n
\pmb{n}
nnn 条对角线以上 的元素;
-
torch.diag(input, diagonal=0, out=None)
:如果输入是一个向量,则返回一个以 i n p u t input input 为对角线元素的 2 d 2d 2d 方阵;如果输入是一个矩阵,则返回一个将 i n p u t input input 第 n n n 条对角线作为元素的 1 d 1d 1d 张量。- d i a g o n a l = 0 diagonal = 0 diagonal=0,主对角线。
- d i a g o n a l > 0 diagonal > 0 diagonal>0,主对角线之上。
-
d
i
a
g
o
n
a
l
<
0
diagonal < 0
diagonal<0,主对角线之下。