目录
一、tensor对象及其运用
Tensor对象是一个任意维度的同类型元素的矩阵。
在深度学习和多种数值计算中,Tensor对象是非常基础且重要的数据结构。它是一个可以包含任意维度的数据集合,但所有的元素必须是相同的数据类型。这种数据结构非常灵活,可以表示标量(零维)、向量(一维)、矩阵(二维)以及更高维度的数组。Tensor支持定义在不同的设备上,如CPU或GPU,使得大量计算能够在GPU上并行执行,从而大幅度提高运算效率。
在实践中,运用Tensor对象主要涉及以下几个方面:
- 创建Tensor:可以通过多种方式创建Tensor对象,例如使用列表、元组或者Numpy数组进行构造。这使得从各种数据源创建Tensor变得非常方便。
- 指定数据类型和设备:通过dtype属性来指定Tensor的数据类型,而device属性用于确定Tensor存储在哪个设备上(CPU或GPU)。
#通过device指定设备 cuda0 = torch.device('cuda:0') c = torch.ones((2,2)) print(c)
- Tensor的操作:Tensor对象支持各种算术运算,包括逐个元素的运算(如相加、相乘等),以及更复杂的操作比如矩阵乘法。熟练地运用这些操作对于深度学习模型的开发至关重要。
a =torch.tensor([[1,2],[3,4]]) b =torch.tensor([[1,2],[3,4]]) c = a * b print('逐元素相乘:', c) c = torch.mm(a,b) print('矩阵相乘:',c)
- 数据转移:在CPU和GPU之间转移数据也是Tensor操作的一部分。这通常发生在需要将数据送入GPU进行加速计算时。
- 索引和切片:与Numpy类似,Tensor也支持索引和切片操作,这使得对高维数据的访问和处理更加便捷。下面内容会做详细解释。
- 类型转换和复制:Tensor提供了接口以方便地进行类型转换和数据复制,这对于确保数据在不同计算阶段保持正确格式非常有用。
二、tensor的索引和切片
Tensor的索引和切片操作允许我们访问和修改张量(多维数组)中的特定数据,类似于NumPy中的操作。以下是关于Tensor索引和切片的详细信息:
- 基本索引:可以通过整数索引直接访问Tensor中的具体元素。例如,对于一个一维Tensor
a = torch.tensor([0, 1, 2, 3, 4])
,a[1]
将返回元素 1。a = torch.arange(9).view(3,3) print(a[2,2])
- 切片操作:在对Tensor进行切片时,使用
start:end
的形式,其中start
是起始索引,end
是结束索引。需要注意的是,起始索引是包含在切片中的,而结束索引是不包含在切片中的。例如,tensor[1:3]
将返回索引为1和2的元素,但不包括索引为3的元素。a = torch.arange(9).view(3,3) print(a[1:,:-1])
- 负数索引:可以使用负数索引来从后面开始对Tensor进行切片。例如,
tensor[-1]
将返回最后一个元素,tensor[-2:]
将返回倒数第二个元素到最后一个元素之间的所有元素。 - 高级索引:除了基本的索引和切片,Tensor还支持更复杂的索引方式,如整数数组索引,这允许按照给定的整数数组来获取Tensor中的元素。
#整数索引 rows = [0,1] cols = [2,2] print(a[rows,cols])
- take函数:
torch.take
函数可以用于根据提供的索引数组来选择元素。它会将输入的Tensor打平后,根据索引选择元素。例如,torch.take(src, torch.tensor([0, 2, 8]))
将返回索引0、2和8对应的元素。 - 多维索引:对于多维Tensor,可以通过指定多个整数索引来访问特定位置的元素,如
a[1][2]
表示访问二维Tensor中第1行第2列的元素。 - 步长:在进行切片操作时,还可以指定步长,例如
a[::2]
表示从Tensora
中每隔一个元素取一个,即取所有偶数索引上的元素。a = torch.arange(9).view(3,3) print(a[::2])
三、tensor的变换、拼接和拆分
1、tensor.nelement、tensor.ndimension、ndimension.size(tensor.shape)
可分别用来查看矩阵元素的个数、轴的个数以及维度
import torch
a= torch.rand(1,2,3,4,5)
print('元素个数:',a.nelement())
print('轴的个数:',a.ndimension())
print('矩阵维度:',a.size(),a.shape)
2、在pytorch中,tensor.view和tensor.reshape都可以用来更改tensor的维度,tensor.view返回的一定是一个索引,原始值和返回值都被更改,tensor.reshape返回的是引用还是复制时不确定的
b = a.view(2*3,4*5)
print(b.shape)
c = a.reshape(-1)
print(c.shape)
d = a.reshape(2*3,-1)
print(d.shape)
3、torch.squeeze 和 torch.unsqueeze用于为Tensor去掉和添加轴。其中torch.squeeze用于去维度为1的轴,而torch.unsqueeze用于给Tensor 的指定位置添加一个维度为1的轴。
b = torch.squeeze(a)
print(b.shape)
print(torch.unsqueeze(b,0).shape)
4、torch.t 和 torch.transpose用于转置二维矩阵。这两个函数只接收二维Tensor,torch.t 是 torch.transpose 的简化版。
a = torch.tensor([[2]])
b = torch.tensor([[2,3]])
print(torch.transpose(a,1,0,))
print(torch.t(a))
print(torch.transpose(b,1,0,))
print(torch.t(b))
5、对于高维度Tensor,可以使用permute方法来变换维度。
a = torch.rand(1,224,224,3)
print(a.shape)
b = a.permute(0,3,1,2)
print(b.shape)
6、PyTorch 提供了torch.cat 和torch.stack用于拼接矩阵。
a = torch.randn(2,3)
b = torch.randn(3,3)
c = torch.cat((a,b))
d = torch.cat((b,b,b),dim = 1)
print(c.shape)
print(d.shape)
c = torch.stack((b,b),dim = 1)
d = torch.stack((b,b),dim = 0)
print(c.shape)
print(d.shape)
7、除了拼接矩阵,PyTorch 还提供了 torch.split和 torch.chui nk用于拆分矩阵。它们的不同之处在于,torch.split 传入的是拆分后每个矩阵的大小,可以传入 li ist,也可以传入整数,而torch.chunk传入的是拆分的矩阵个数。
a = torch.randn(10,3)
for x in torch.split(a,[1,2,3,4],dim=0):
print(x.shape)
for x in torch.split(a,4,dim=0):
print(x.shape)
for x in torch.chunk(a,4,dim=0):
print(x.shape)