一、Tensor的索引和切片
1、什么是索引:
- Tensor的索引是指得到tensor中元素的唯一标识符。
- Tensor的索引是用来标识多维数组中各个元素的位置和大小等信息。在PyTorch等框架中,张量(Tensor)是用于存储任意类型数据的多维数组,我们可以通过索引来访问和操作其中的元素。索引可以是单个数字,也可以是数字组成的序列,甚至是其他张量,这后者的情况被称为花式索引或高级索引。通过索引我们能精确地选取到想要操作的数据项,这对于数据访问和算法实现至关重要。
2、什么是切片:
-
Tensor的切片则是指从一个tensor中选择特定的元素或子集的操作。
-
切片与索引类似,都是用来获取特定元素的手段,但切片通常指的是按照一定的连续范围来获取数据。在PyTorch等框架中,可以通过切片的方式获取tensor中的某一段数据。基本切片操作可以指定开始、结束和步长来获取所需的部分数据。例如,
a[:2].shape
将会得到张量a
的前两个元素所形成的新张量的形状。除了基本切片以外,还可以利用函数如index_select()
来进行不规则间隔的索引,即选取指定维度上的特定元素形成新的张量。
3、Tensor支持件基本索引和切片操作,不仅如此,它还支持ndarray中的高级索引(整数索引和布尔索引)操作。
# 导入PyTorch库
import torch
import numpy as np
# 创建一个3x3的张量,并打印出第三行的所有元素。
# 创建一个包含从0到8的整数的一维张量,并使用view()函数将其重塑为一个3x3的二维张量。这个张量被赋值给变量a。
a = torch.arange(9).view(3, 3)
# 基本索引
print(a[2]) # 打印变量a中索引为2的元素。由于a是一个二维张量,索引2表示第三行的所有元素。
# 切片
print(a[1:, :-1]) # a[1:, :-1]表示获取从索引为1的元素开始到倒数第二个元素的所有元素。
# a[1:]表示从索引为1的元素开始到末尾的所有元素,[:-1]表示从开头到倒数第二个元素的所有元素。
# 带步长的切片(PyTorch现在不支持负步长)
print(a[::2]) # a[::2]表示获取从开头到末尾,每隔一个元素的子序列,每隔一个元素取一个元素。
# 整数索引
rows = [0, 1] # 定义了一个名为rows的列表,其中包含两个整数元素0和1。
cols = [2, 2] # 定义了一个名为cols的列表,其中包含两个整数元素2和2
print(a[rows, cols]) # 使用整数索引来获取数组或矩阵a中指定位置的元素。
# a[rows, cols]表示获取位于第0行第2列和第1行第2列的元素。然后,通过print()函数将这些元素打印输出
# 布尔索引
index = a > 4 # 定义了一个名为index的变量,其中包含一个布尔类型的张量。这个张量表示数组或矩阵a中每个元素是否大于4。
print(index) # 打印输出布尔索引的结果,即index的值
print(a[index]) # 使用布尔索引来获取数组或矩阵a中满足条件的元素
# a[index]表示获取所有大于4的元素。然后,通过print()函数将这些元素打印输出。
4、torch.nonzero用于返回非零值的索引矩阵。
a = torch.arange(9).view(3, 3) # 创建一个从0到8的整数序列,并将其重塑为一个3x3的二维张量。这个张量被赋值给变量a。
index = torch.nonzero(a >= 8) # 定义了一个名为index的变量,其中包含一个整数类型的张量。这个张量表示数组或矩阵a中每个元素是否大于等于8。
print(index) # 打印输出非零元素的索引,即index的值。
a = torch.randint(0, 2, (3, 3)) # 创建一个随机整数类型的3x3张量,其值在0和1之间(包括0和1)。这个张量被赋值给变量a。
print(a)
index = torch.nonzero(a) # 定义了一个名为index的变量,其中包含一个整数类型的张量。这个张量表示数组或矩阵a中非零元素的索引。
print(index)
5、torch.where(condition,x,y)判断condition的条件是否满足。当某个元素满足条件时,则返回对应矩阵x相同的位置的元素,否则返回矩阵y的元素。
x = torch.randn(3, 2) # 创建了一个形状为(3, 2)的随机张量,并将其赋值给变量x。这个张量的元素是从标准正态分布中随机生成的。
y = torch.ones(3, 2) # 创建了一个形状为(3, 2)的全1张量,并将其赋值给变量y
print(x)
print(torch.where(x > 0, x, y)) # 使用torch.where函数根据条件选择元素。如果x中的元素大于0,则选择对应的x中的元素,否则选择对应的y中的元素。
6、矩阵与矩阵乘法 的理解。
-
矩阵可以看作是一个由数字排列成的矩形阵列,其中m行n列的数字构成一个m×n的矩阵。在几何意义上,可以将矩阵视作线性变换,例如在图形处理中,它可以用来表示旋转、缩放等操作。
-
其次,矩阵乘法是两个矩阵相乘得到另一个矩阵的过程。其基本规则是第一个矩阵的列数必须和第二个矩阵的行数相同。矩阵乘积的结果是一个新矩阵,其行数来自第一个矩阵,列数来自第二个矩阵。具体到数学公式上,若有两个矩阵A和B,A是一个m×n的矩阵,B是一个n×k的矩阵,那么它们的乘积AB将是一个m×k的矩阵,并且AB中的每个元素是通过对应的行与列进行乘积求和得到的。