一、概念
pytorch框架是Python中对张量进行处理的包,提供各种模块实现各种功能,其中数据是以张量类型存储的。
二、创建张量
1.基本创建方式:(主要分为小写和大写两种)
torch.tensor(data=,dtype=),data指定数据,dtype指定类型,无法指定形状
import torch
import numpy as np
print(torch.tensor(10))
print(torch.tensor([10]).ndim)
print(torch.tensor([[10,20],[40,50]],dtype=torch.float))
torch.Tensor(data=, size=()),既能指定数据,又能指定形状
print(torch.Tensor(5,3)) #默认指定形状
print(torch.Tensor(size=(2,3))) #指定形状
print(torch.Tensor([10]).ndim)
print(torch.Tensor([[10]]).ndim)
# 通过列表创建张量
print(torch.Tensor([[10, 20], [40, 50]]), torch.tensor([[10, 20], [40, 50]]).dtype)
# 通过numpy创建张量
print(torch.Tensor(np.array([10,20,30])))
torch.IntTensor()/FloatTensor()...,Tensor直接指定类型创建
print(torch.ShortTensor([[10, 20]]))
print(torch.DoubleTensor([[10, 20]]))
2.线性和随机张量
线性张量:
# arange() 包左不包右
print(torch.arange(2, 10))
print(torch.arange(2, 10, 2))
# linspace() 包左包右
print(torch.linspace(2, 10, 5))
随机张量:
print(torch.rand(2, 3)) # rand随机生成0-1的浮点数张量
print(torch.randn(2, 3)) # randn随机生成正态分布的浮点数张量
print(torch.randint(10, (2, 3))) # randint随机生成整数张量
随机种子的设置和获取:
#设置随机数种子
print(torch.manual_seed(666))
#获取随机数种子
print(torch.initial_seed())
3.0/1/指定值张量
import torch
# zeros(size) 创建全0张量
print(torch.zeros(2, 3))
# zeros_like(张量) 模仿指定张量的形状创建全0张量
x = torch.randint(0, 10, (3, 3))
print(torch.zeros_like(x))
# ones(size) 创建全1张量
print(torch.ones(2, 3))
# ones_like(张量) 模仿指定张量的形状创建全1张量
x = torch.randint(0, 10, (3, 3))
print(torch.ones_like(x))
# full(size,value) 创建指定值张量
print(torch.full((2, 3), 8))
# full_like(张量) 模仿指定张量的形状创建全指定值张量
x = torch.randint(0, 10, (3, 3))
print(torch.full_like(x, 9))
4.张量中数据类型以及转换操作
①张量.类型函数():
data = torch.randint(0, 10, [2, 5])
print(data.short())
print(data.int())
print(data.long())
print(data.half())
print(data.float())
print(data.double())
②张量.type(指定类型)
类型转换方式1 torch.小写类型名
print(data.type(torch.int))
print(data.type(torch.long), data.type(torch.long).dtype)
print(data.type(torch.half))
print(data.type(torch.float), data.type(torch.float).dtype)
类型转换方式2 torch.int位数/torch.float位数
print(data.type(torch.int16))
print(data.type(torch.int32))
print(data.type(torch.int64))
print(data.type(torch.float16))
print(data.type(torch.float32))
print(data.type(torch.float64))
类型转换方式3 torch.大写类型Tensor
#会有警告
print(data.type(torch.ShortTensor))
print(data.type(torch.DoubleTensor))
三、tensor和numpy的相互转换
①numpy数组转换成张量
# TODO from_numpy()将numpy数组转换为张量,但是共享内存
n1 = np.array([1, 2, 3])
t1 = torch.from_numpy(n1)
print(n1, type(n1))
print(t1, type(t1))
# 演示from_numpy()结果共享内存
n1[0] = 100
print(n1, id(n1))
print(t1, id(t1))
# TODO torch.tensor(ndarray)将numpy数组转换为张量,不会共享内存
# 创建numpy数组
n2 = np.array([1, 2, 3])
t2 = torch.tensor(n2)
print(n2, type(n2))
print(t2, type(t2))
# 演示tensor(ndarray)结果不共享内存
n2[0] = 200
print(n2, id(n2))
print(t2, id(t2))
②张量转换成numpy数组
# TODO numpy()将张量转换为numpy数组
t3 = torch.tensor([1, 2, 3])
n3 = t3.numpy()
print(t3, type(t3))
print(n3, type(n3))
# 演示numpy()结果共享内存
t3[0] = 300
print(t3, id(t3))
print(n3, id(n3))
# TODO numpy().copy()将张量转换为numpy数组
t4 = torch.tensor([1, 2, 3])
n4 = t4.numpy().copy()
print(t4, type(t4))
print(n4, type(n4))
# 演示numpy().copy()结果不共享内存
t4[0] = 300
print(t4, id(t4))
print(n4, id(n4))
③仅有一个元素的张量和标量互转
# 标量转张量
a1 = 10
t1 = torch.tensor(a1)
print(a1, type(a1))
print(t1, type(t1))
print('------------------')
# 张量转标量
a2 = t1.item()
print(t1, type(t1))
print(a2, type(a2))
四、张量的运算
1.基础运算
+:add() -\:sub() *:mul() /:div() 负号:neg、neg_()
2.矩阵乘法运pow算
符号:@,算法为矩阵点乘的计算方法
要求:A@B,其中必须A的列=B的行
api:torch.matmul()
3.运算函数
max,min,mean,sqrt(),
sum(dim=)设置dim参数对应维度元素计算,否则默认所有元素,
log()默认以e为底,
pow(exponent=)幂次方,exp()指数
注意:张量中也存在dot()函数,但是和Numpy中有所区别的是,在张量中它只能对一维张量进行操作。
五、张量索引操作
1.核心
作用:根据索引获取对应位置的数据
格式:张量[行,列]
2.获取单独行(列)
#获取第二行的数据
print(data[1, :])
#获取第二列的数据
print(data[:,1])
3.获取列表指定多行多列
#获取张量中第一行和第三行的数据
print(data[[0,2],:])
#获取第一列和第三列的数据
print(data[:,[0,2]])
# TODO 获取张量中第一行和第三行中第一列和第三列的数据(易错)
#先写错误示范(输出第一行一列和第三行三列这两个数)
print(data[[0,2],[0,2]])
#下面写正确的示范
print(data[[[0],[2]],[0,2]])
4.获取切片指定连续行列
#获取第一行到第三行的数据
print(data[0:3,:])
#获取第一列到第三列的数据
print(data[:,0:3])
5.获取布尔值为Ture行列
import torch
#举例:有一张量第一行为[1,2,6,7,5]
data = torch.tensor([[1,2,6,7,5],[2,4,6,8,9]])
#判断第一行哪些数据>5,是输出Ture,否则输出False
print(data[0] > 5)
#根据布尔索引获取对应列的数据
print(data[:, torch.tensor([False, False, True, True, False])])
print(data[:,data[0]>5])
6.多维索引
从上图可知,以O点为起点分别引出三条轴(0,1,2)
获取0轴上的第一个数据:
看正视图,得出[[5,7,2,4],[6,7,5,8],[9,8,9,8]]
获取1轴上的第一个数据:
看俯视图,得出[[5,7,2,4],[6,6,7,1],[1,2,3,4]]
获取2轴上的第一个数据:
看左视图,得出[[5,6,9],[6,7,4],[1,4,6]]
下面是代码举例(和上图数据无关):
import torch
# 提前设置种子
torch.manual_seed(666)
# 创建三位张量
data = torch.randint(1, 10, (3, 4, 5))
print(data)
print(data[:, :, :])
# 格式: data[0轴索引,1轴索引,2轴索引]
print('----------------------------------')
# 获取0轴上的第一个数据
print(data[0, :, :])
print(data[0])
print('----------------------------------')
# 获取1轴上的第一个数据
print(data[:, 0, :])
print('----------------------------------')
# 获取2轴上的第一个数据
print(data[:, :, 0])
六、张量形状和维度操作
1.形状问题
reshape修改形状元素个数不会改变,且可以修改连续和非连续张量。
import torch
#创建张量
t = torch.tensor([[1, 2, 3], [4, 5, 6]])
#shape获取形状
print(t.shape, t.shape[0], t.shape[1], t.shape[-1])
print(t.size(), t.size()[0], t.size()[1], t.size()[-1])
# TODO reshape修改形状(元素个数不会变化)
print(t.reshape(3, 2))
print('--------------------')
2.降维升维问题
降维:tensor.squeeze(dim=),默认删除所有维度值为1的维度, 可以通过dim指定删除维度
升维:tensor.unsqueeze(dim=),在指定维度上增加维度值1
维度升高只能一层一层加,不能超出范围
import torch
# 创建张量
t = torch.tensor([1, 2, 3, 4, 5, 6])
#升维
t4 = t.unsqueeze(dim=0)
print(t4, t4.shape, t4.ndim)
t5 = t.unsqueeze(dim=1)
print(t5, t5.shape, t5.ndim)
#降维
t6 = t4.squeeze()
print(t6, t6.shape, t6.ndim)
t7 = t5.squeeze()
print(t7, t7.shape, t7.ndim)
3.交换维度数据
transpose方式:每次只能交换两个维度;
permute方式:一次指定多个维度。
import torch
# 提前设置一个种子
torch.manual_seed(666)
# 创建三维张量
t = torch.randint(1, 5, (3, 4, 5))
#把张量形状(3,4,5)转变为(4,5,3)
#方式1:transpose方式:多次交换,每次只能交换两个维度
t2 = t.transpose(1, 0) # (3,4,5)->(4,3,5)
print(t2, t2.shape)
t3 = t2.transpose(2, 1) # (4,3,5)->(4,5,3)
print(t3, t3.shape)
#方式2:permute方式: 一次指定多个维度
t4 = t.permute(dims=(1, 2, 0)) # (3,4,5)->(4,5,3)
print(t4, t4.shape)
4.判断是否连续以及修改操作
is_contiguous()判断张量是否连续
import torch
# 创建张量
t1 = torch.tensor([[10, 20, 30], [40, 50, 60]])
t2 = t1.transpose(1, 0)
print(t2, t2.shape, t2.ndim)
# TODO is_contiguous()判断张量是否连续
# 判断t1是否连续
print(t1.is_contiguous())
# 判断t2是否连续
print(t2.is_contiguous())
# TODO 演示reshape和view的区别:
# todo 1.reshape连续和不连续张量都能操作,以后工作中常用reshape()
print(t1.reshape(1, 6))
print(t2.reshape(1, 6))
# todo 2.view只能操作连续张量,如果是不连续则需要使用contiguous()
print(t1.view(1, 6))
# print(t2.view(1, 6)) # 报错,因为t2是不连续的
# TODO 使用contiguous()把数据变为连续的,然后使用view()
t3 = t2.contiguous()
print(t3.is_contiguous())
print(t3.view(1, 6))
七、张量的拼接
1.cat()拼接
除需要拼接的那一维度外,其他所有张量形状必须完全相同。
2.stack()拼接
所有张量形状必须完全相同(所有维度一致)。
import torch
torch.manual_seed(666)
# TODO cat()拼接三维案例:对应维度上的维数相加
# 创建张量
t1 = torch.randint(1, 5, (1, 2, 3))
t2 = torch.randint(1, 5, (1, 2, 3))
# 拼接
print(torch.cat([t1, t2], dim=0)) # (2, 2, 3)
print('----------------------------')
# 拼接
print(torch.cat([t1, t2], dim=1)) # (1, 4, 3)
print('----------------------------')
# 拼接
print(torch.cat([t1, t2], dim=2)) # (1, 2, 6)
print('=========================================================')
t1 = torch.randint(1, 5, (2, 3))
t2 = torch.randint(1, 5, (1, 3))
# 拼接0轴上
print(torch.cat([t1, t2], dim=0)) # (3, 3)
print('=========================================================')
# TODO stack()拼接的是所有张量的形状必须完全相同(所有维度一致)。
t1 = torch.randint(1, 5, (2, 3))
t2 = torch.randint(1, 5, (2, 3))
print(torch.stack([t1, t2], dim=0)) # (2,2,3)
print(torch.stack([t1, t2], dim=1)) # (2,2,3)
print(torch.stack([t1, t2], dim=2)) # (2,3,2)