PyTorch | 索引与切片

1. 普通索引

import torch
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"

photo = torch.rand(4,3,28,28)
photo.shape
photo[2].shape # 索引到第一个维度
photo[3,2].shape # 索引到第二个维度
photo[1,2,3].shape # 索引到第三个维度
photo[1,2,3,4] # 索引到具体元素,返回一个标量

在这里插入图片描述

2. 冒号索引(切片)

import torch
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"

photo = torch.rand(4,3,28,28)
photo.shape
# 索引到第一维度0,1(不包含2)
photo[:2].shape 
# 索引到第一维度0,1(不包含2);第二维度2
photo[:2, 2:].shape 
# 索引到第一个维度0,1(不包含2);第二个维度1,2;第三、四维度所有
photo[:2, 1:, :, :].shape 
# 索引到第一个维度0,1(不包含2);第二个维度2;第三维度从倒数第三到最后;第四维度所有
photo[:2, 2:, -3:, :].shape
# 索引到第一个维度0,1(不包含2);第二个维度2;第三维度从0到27,每两步取一个元素;第四维度所有
photo[:2, 2:, 0:28:2, :].shape 

在这里插入图片描述

3. index_select 选择特定索引

  • torch.index_select(input, dim, index):沿指定维度 d i m dim dim 对输入进行切片,取 i n d e x index index 中指定的相应项,然后返回一个新的张量,返回的张量与原始张量有相同的维度(在指定轴上),返回的张量与原始张量不共享内存空间。

    下面 p h o t o photo photo 这个 t e n s o r tensor tensor 可以看作 7 7 7 R G B ( 3 通 道 ) RGB(3通道) RGB(3) M N I S T MNIST MNIST 图像,长宽都是 28 p x 28px 28px。那么在第一维度上可以选择特定的图片,在第二维度上选择特定的通道,在第三维度上选择特定的行,在第四维度上选择特定的列等。

    import torch
    from IPython.core.interactiveshell import InteractiveShell
    InteractiveShell.ast_node_interactivity = "all"
    
    photo = torch.rand(7,3,28,28)
    photo.shape
    # 选择第一张和第三张图
    photo.index_select(0, torch.tensor([0, 2])).shape
    # 选择R通道和B通道
    photo.index_select(1, torch.tensor([0, 2])).shape
    # 选择图像的0~8行
    photo.index_select(2, torch.arange(8)).shape
    # 选择图像的0~8列
    photo.index_select(3, torch.arange(8)).shape
    

4. masked_select 选择符合条件的索引

  • torch.masked_select(input,mask):根据一个布尔掩码 ( b o o l e a n   m a s k ) (boolean\ mask) (boolean mask) 索引返回一个一维张量。 m a s k e d _ s e l e c t masked\_select masked_select 索引是在原来 t e n s o r tensor tensor s h a p e shape shape 基础上打平,然后在打平后的 t e n s o r tensor tensor 上进行索引。返回的张量不与原始张量共享内存空间。
    在这里插入图片描述
    import torch
    from IPython.core.interactiveshell import InteractiveShell
    InteractiveShell.ast_node_interactivity = "all"
     
    r = torch.randn(3, 4)
    r
     
    # 生成r这个tensor中大于0.5的元素的掩码
    mask = r.ge(0.5)
    mask
     
    # 取出r这个Tensor中大于0.5的元素
    mr = torch.masked_select(r, mask)
    mr
    mr.shape
    
    在这里插入图片描述

5. take索引

  • torch.take(input, index) t a k e take take 索引是在原来 t e n s o r tensor tensor s h a p e shape shape 基础上打平,然后按照 i n d e x index index 在打平后的 t e n s o r tensor tensor 上索取对应位置的元素。
    import torch
    from IPython.core.interactiveshell import InteractiveShell
    InteractiveShell.ast_node_interactivity = "all"
     
    src = torch.randn(3, 4)
    src
    torch.take(src, torch.tensor([1,3,5,7,9,11]))
    
    在这里插入图片描述

6. 使用 … 索引任意多的维度

import torch
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"

photo = torch.rand(7,3,28,28)

photo.shape
# 等于photo.shape
photo[...].shape
# 第一张图片的所有维度
photo[0,...].shape
# 所有图片第二通道的所有维度
photo[:,1,...].shape
# 所有图像所有通道第一、第二行的所有列
photo[...,:2,:].shape
# 所有图像所有通道所有行的第一、第二列
photo[...,:2].shape

在这里插入图片描述

7. gather 函数

  • torch.gather(input, dim, index, out=None):将输入 i n p u t input input 张量按照指定的维度 d i m dim dim 和设定的索引 i n d e x ( L o n g T e n s o r ) index(LongTensor) index(LongTensor) 取值,返回一个新的 t e n s o r tensor tensor i n d e x ( L o n g T e n s o r ) index(LongTensor) index(LongTensor) 的维度要与 i n p u t input input 维度相同。 d i m dim dim 为从左到右的维度,将 i n d e x ( L o n g T e n s o r ) index(LongTensor) index(LongTensor) 的索引的 d i m dim dim 维度值修改为 i n d e x index index 中的值,然后到 i n p u t input input 张量中进行索引,就生成了 g a t h e r gather gather 返回的 t e n s o r tensor tensor
    在这里插入图片描述

8. where 函数

  • torch.where(condition, x, y):根据 c o n d i t i o n condition condition 合并两个 t e n s o r tensor tensor 类型,按照一定的规则从 x , y x,y x,y 中选择元素组成新的张量, c o n d i t i o n condition condition 是规则, x x x y y y 是同 s h a p e shape shape 的矩阵。如果满足条件,则返回 x x x 中元素。若不满足,返回 y y y 中元素。
    在这里插入图片描述

8. tril & triu & diag 函数

  • torch.tril(input, diagonal=0, out=None):返回一个张量,包含输入张量( i n p u t input input)的下三角部分,其余部分设为 0 0 0,参数 d i a g o n a l diagonal diagonal 控制对角线。

    • 如果 d i a g o n a l diagonal diagonal 为空,输入矩阵保留 主对角线主对角线下方 的所有元素;
    • 如果 d i a g o n a l diagonal diagonal 为正数 n n n,输入矩阵保留主对角线上方第 n \pmb{n} nnn 条对角线主对角线上方第 n \pmb{n} nnn 条对角线以下 的所有元素;
    • 如果 d i a g o n a l diagonal diagonal 为负数 − n -n n,输入矩阵保留 主对角线下方第 n \pmb{n} nnn 条对角线主对角线下方第 n \pmb{n} nnn 条对角线以下 的元素;
      在这里插入图片描述
  • torch.triu(input, diagonal=0, out=None):返回一个张量,包含输入张量( i n p u t input input)的上三角部分,其余部分设为 0 0 0,参数 d i a g o n a l diagonal diagonal 控制对角线。

    • 如果 d i a g o n a l diagonal diagonal 为空,输入矩阵保留 主对角线主对角线上方 的所有元素;
    • 如果 d i a g o n a l diagonal diagonal 为正数 n n n,输入矩阵保留主对角线上方第 n \pmb{n} nnn 条对角线主对角线上方第 n \pmb{n} nnn 条对角线以上 的所有元素;
    • 如果 d i a g o n a l diagonal diagonal 为负数 − n -n n,输入矩阵保留 主对角线下方第 n \pmb{n} nnn 条对角线主对角线下方第 n \pmb{n} nnn 条对角线以上 的元素;
      在这里插入图片描述
  • torch.diag(input, diagonal=0, out=None):如果输入是一个向量,则返回一个以 i n p u t input input对角线元素 2 d 2d 2d 方阵;如果输入是一个矩阵,则返回一个将 i n p u t input input n n n 条对角线作为元素的 1 d 1d 1d 张量。

    • d i a g o n a l = 0 diagonal = 0 diagonal=0,主对角线。
    • d i a g o n a l > 0 diagonal > 0 diagonal>0,主对角线之上。
    • d i a g o n a l < 0 diagonal < 0 diagonal<0,主对角线之下。
      在这里插入图片描述
  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

软耳朵DONG

觉得文章不错就鼓励一下作者吧

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值