tensor基本操作
import torch
import numpy as np
DEVICE='cuda' if torch.cuda.is_available() else 'cpu'
# 初始化方法1
x=torch.tensor([[1,2,7],[3,4,5]],dtype=torch.float32,device=DEVICE,requires_grad=True)
# 初始化方法2
input=torch.rand((3,3))
x=torch.rand_like(input) # 产生与input相同尺寸的tensor 数据时均匀分布的
x=torch.ones((2,3)) # 全1
# 常见方法
torch.is_tensor(x) # 判断是不是一个tensor
torch.numel(x) # 统计tensor中元素的个数
torch.zeros((2,3)) # 创建全0的tensor
torch.eye(3,3) # 创建对角阵
x=np.array([1,2,3])
x=torch.from_numpy(x) # 将numpy装换成tensor
torch.linspace(start=2,end=10,steps=5) # 将2-10直接的数字切分为5份
torch.rand((2,3)) # 创建数据为均匀分布
torch.randn((2,3)) # 创建数据为正态分布
torch.randperm(10) # 得到0-9 10个数 但是顺序打乱
torch.arange(start=0,end=10,step=1) #生成0-10区间的数 步长为1 但是不包含10
x=torch.randint(low=1,high=5,size=(2,3)) #可以规定元素的最大最小值
torch.argmax(x,dim=0) # 获取每一列的最大值的下标 dim=0每一列 dim=1是每一行
x=torch.cat((x,x),dim=0) # 将两个tensor堆叠起来 dim=0是纵方向的堆叠 dim=1是横方向的堆叠
list= torch.chunk(x,chunks=2,dim=0) #将一个tensor切开 chunks表示切到哪里为止 返回一个tensor列表
index=torch.tensor([0,2])
x=torch.index_select(x,dim=0,index=index) # 根据索引选择
torch.split(x,3,dim=0) # 将tensor切分为若干份 每份是3个元素
x.t() # 矩阵转置
torch.add(x,1) # 每个元素+1
torch.mul(x,2) # 每个元素乘上2
# 维度调整
x.reshape/view
李宏毅老师课上的快速入门pytorch补充(关键是多了图片)
torch.zeros([1,2,3]) # 创建一个1*2*3立方体 这个立方体的高是1,长是2,宽是3
torch.squeeze(dim=0) # 降维,消除的是第1个维度 如将1*2*3变为2*3的矩形
torch.zeros([2,3]) # 创建一个2*3矩阵
torch.squeeze(dim=1) # 升维,增加的是第2个维度 如将2*3的矩形变为2*1*3的立方体
x=torch.zeros([2,1,3])
y=torch.zeros([2,3,3])
z=torch.zeros([2,2,3])
w=torch.cat([x,y,z],dim=1) # 在第一个维度(也就是长)将他们拼起来 w.shape=[2,6,3]
方便的梯度计算
Pytorch函数原型中,若是出现*号,则后面的参数全部都是keyword arguments,也就是用到这些参数时候,必须指定参数名
常见错误
*x is a keyword argument
:
对于函数原型中*之后的参数全部都是关键参数,如果使用必须书写参数名
*did not specify dim
:
有些参数,没有默认值,一定要赋值
Tensor for * is on CPU, but expected them to be on GPU
:使用设备要配对,要么GPU,要么CPU,可以全局定义好device,然后统一用它
The size of tensor a (5) must match the size of tensor b (4) at non-singleton dimension 1
:
tensor运算的时候尺寸不对,可以使用transpose、unsqueeze、squeeze等操作将尺寸变为一致之后再处理
CUDA out of memory.
:显存爆了,把batch-size调小一点
expected scalar type Long but found Float
:类型不匹配