PyTorch | 索引与切片

本文详细介绍了PyTorch中不同类型的索引和切片操作,包括普通索引、冒号索引、index_select、masked_select、take、...索引、gather函数以及where、tril、triu和diag函数的用法。通过实例展示了如何在多维张量中选取特定元素或满足特定条件的元素,帮助读者深入理解PyTorch张量操作。
摘要由CSDN通过智能技术生成

1. 普通索引

import torch
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"

photo = torch.rand(4,3,28,28)
photo.shape
photo[2].shape # 索引到第一个维度
photo[3,2].shape # 索引到第二个维度
photo[1,2,3].shape # 索引到第三个维度
photo[1,2,3,4] # 索引到具体元素,返回一个标量

在这里插入图片描述

2. 冒号索引(切片)

import torch
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"

photo = torch.rand(4,3,28,28)
photo.shape
# 索引到第一维度0,1(不包含2)
photo[:2].shape 
# 索引到第一维度0,1(不包含2);第二维度2
photo[:2, 2:].shape 
# 索引到第一个维度0,1(不包含2);第二个维度1,2;第三、四维度所有
photo[:2, 1:, :, :].shape 
# 索引到第一个维度0,1(不包含2);第二个维度2;第三维度从倒数第三到最后;第四维度所有
photo[:2, 2:, -3:, :].shape
# 索引到第一个维度0,1(不包含2);第二个维度2;第三维度从0到27,每两步取一个元素;第四维度所有
photo[:2, 2:, 0:28:2, :].shape 

在这里插入图片描述

3. index_select 选择特定索引

  • torch.index_select(input, dim, index):沿指定维度 d i m dim dim 对输入进行切片,取 i n d e x index index 中指定的相应项,然后返回一个新的张量,返回的张量与原始张量有相同的维度(在指定轴上),返回的张量与原始张量不共享内存空间。

    下面 p h o t o photo photo 这个 t e n s o r tensor tensor 可以看作 7 7 7 R G B ( 3 通 道 ) RGB(3通道) RGB(3) M N I S T MNIST MNIST 图像,长宽都是 28 p x 28px 28px。那么在第一维度上可以选择特定的图片,在第二维度上选择特定的通道,在第三维度上选择特定的行,在第四维度上选择特定的列等。

    import torch
    from IPython.core.interactiveshell import InteractiveShell
    InteractiveShell.ast_node_interactivity = "all"
    
    photo = torch.rand(7,3,28,28)
    photo.shape
    # 选择第一张和第三张图
    photo.index_select(0, torch.tensor([0, 2])).shape
    # 选择R通道和B通道
    photo.index_select(1, torch.tensor([0, 2])).shape
    # 选择图像的0~8行
    photo.index_select(2, torch.arange(8)).shape
    # 选择图像的0~8列
    photo.index_select(3, torch.arange(8)).shape
    

4. masked_select 选择符合条件的索引

  • torch.masked_select(input,mask):根据一个布尔掩码 ( b o o l e a n   m a s k ) (boolean\ mask) (boolean mask) 索引返回一个一维张量。 m a s k e d _ s e l e c t masked\_select masked_select 索引是在原来 t e n s o r tensor tensor s h a p e shape shape 基础上打平,然后在打平后的 t e n s o r tensor tensor 上进行索引。返回的张量不与原始张量共享内存空间。
    在这里插入图片描述
    import torch
    from IPython.core.interactiveshell import InteractiveShell
    InteractiveShell.ast_node_interactivity = "all"
     
    r = torch.randn(3, 4)
    r
     
    # 生成r这个tensor中大于0.5的元素的掩码
    mask = r.ge(0.5)
    mask
     
    # 取出r这个Tensor中大于0.5的元素
    mr = torch.masked_select(r, mask)
    mr
    mr.shape
    
    在这里插入图片描述

5. take索引

  • torch.take(input, index) t a k e take take 索引是在原来 t e n s o r tensor tensor s h a p e shape shape 基础上打平,然后按照 i n d e x index index 在打平后的 t e n s o r tensor tensor 上索取对应位置的元素。
    import torch
    from IPython.core.interactiveshell import InteractiveShell
    InteractiveShell.ast_node_interactivity = "all"
     
    src = torch.randn(3, 4)
    src
    torch.take(src, torch.tensor([1,3,5,7,9,11]))
    
    在这里插入图片描述

6. 使用 … 索引任意多的维度

import torch
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"

photo = torch.rand(7,3,28,28)

photo.shape
# 等于photo.shape
photo[...].shape
# 第一张图片的所有维度
photo[0,...].shape
# 所有图片第二通道的所有维度
photo[:,1,...].shape
# 所有图像所有通道第一、第二行的所有列
photo[...,:2,:].shape
# 所有图像所有通道所有行的第一、第二列
photo[...,:2].shape

在这里插入图片描述

7. gather 函数

  • torch.gather(input, dim, index, out=None):将输入 i n p u t input input 张量按照指定的维度 d i m dim dim 和设定的索引 i n d e x ( L o n g T e n s o r ) index(LongTensor) index(LongTensor) 取值,返回一个新的 t e n s o r tensor tensor i n d e x ( L o n g T e n s o r ) index(LongTensor) index(LongTensor) 的维度要与 i n p u t input input 维度相同。 d i m dim dim 为从左到右的维度,将 i n d e x ( L o n g T e n s o r ) index(LongTensor) index(LongTensor) 的索引的 d i m dim dim 维度值修改为 i n d e x index index 中的值,然后到 i n p u t input input 张量中进行索引,就生成了 g a t h e r gather gather 返回的 t e n s o r tensor tensor
    在这里插入图片描述

8. where 函数

  • torch.where(condition, x, y):根据 c o n d i t i o n condition condition 合并两个 t e n s o r tensor tensor 类型,按照一定的规则从 x , y x,y x,y 中选择元素组成新的张量, c o n d i t i o n condition condition 是规则, x x x y y y 是同 s h a p e shape shape 的矩阵。如果满足条件,则返回 x x x 中元素。若不满足,返回 y y y 中元素。
    在这里插入图片描述

8. tril & triu & diag 函数

  • torch.tril(input, diagonal=0, out=None):返回一个张量,包含输入张量( i n p u t input input)的下三角部分,其余部分设为 0 0 0,参数 d i a g o n a l diagonal diagonal 控制对角线。

    • 如果 d i a g o n a l diagonal diagonal 为空,输入矩阵保留 主对角线主对角线下方 的所有元素;
    • 如果 d i a g o n a l diagonal diagonal 为正数 n n n,输入矩阵保留主对角线上方第 n \pmb{n} nnn 条对角线主对角线上方第 n \pmb{n} nnn 条对角线以下 的所有元素;
    • 如果 d i a g o n a l diagonal diagonal 为负数 − n -n n,输入矩阵保留 主对角线下方第 n \pmb{n} nnn 条对角线主对角线下方第 n \pmb{n} nnn 条对角线以下 的元素;
      在这里插入图片描述
  • torch.triu(input, diagonal=0, out=None):返回一个张量,包含输入张量( i n p u t input input)的上三角部分,其余部分设为 0 0 0,参数 d i a g o n a l diagonal diagonal 控制对角线。

    • 如果 d i a g o n a l diagonal diagonal 为空,输入矩阵保留 主对角线主对角线上方 的所有元素;
    • 如果 d i a g o n a l diagonal diagonal 为正数 n n n,输入矩阵保留主对角线上方第 n \pmb{n} nnn 条对角线主对角线上方第 n \pmb{n} nnn 条对角线以上 的所有元素;
    • 如果 d i a g o n a l diagonal diagonal 为负数 − n -n n,输入矩阵保留 主对角线下方第 n \pmb{n} nnn 条对角线主对角线下方第 n \pmb{n} nnn 条对角线以上 的元素;
      在这里插入图片描述
  • torch.diag(input, diagonal=0, out=None):如果输入是一个向量,则返回一个以 i n p u t input input对角线元素 2 d 2d 2d 方阵;如果输入是一个矩阵,则返回一个将 i n p u t input input n n n 条对角线作为元素的 1 d 1d 1d 张量。

    • d i a g o n a l = 0 diagonal = 0 diagonal=0,主对角线。
    • d i a g o n a l > 0 diagonal > 0 diagonal>0,主对角线之上。
    • d i a g o n a l < 0 diagonal < 0 diagonal<0,主对角线之下。
      在这里插入图片描述
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

软耳朵DONG

觉得文章不错就鼓励一下作者吧

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值