pytorch学习 (第一个周)
1、步骤
- 下载数据集
- 新建层
- 训练
- 测试
# utils
# 画曲线、绘图、one_hot 编码
#头文件
import torch
from torch import nn
from torch.nn import functional as F
from torch inport optim
import torchivision
from matplotlib import pyplot as plt
from utils import plot_image,plot_curve,one_hot
# 下载 图片资源
# 加载数据集 train、test 均匀分布在0附近
# torch.Size([512,1,28,28]) 含义:512张图片,1个通道,28行28列
基本数据类型
int->IntTensor of size()
float->FloatTensor of size()
byte-> ByteTensor
int[] ->IntTensor of size()[d1,d2,…]
float[]->FloatTensor of size() [d1,d2,…]
string->没有string类型,用整型表示采用one-hot 编码方式
一维向量
应用于:偏置bias,Linear input线性输入
dim 维度
size 具体形状维度x*y…
二维
三维
RNN Input Batch
四维
CNN [b,c,h,w]
适合表示图片
[图片数,通道,长,宽]
内存大小
a.numel()
a.dim() <=> len(a.shape)
区别:Tensor = FloatTensor(数据的维度或者[list表示的具体数据] );tensor(具体的数据)
注意:tensor\Tensor的类型为默认类型,一般为FloatTensor
但是为了增强学习,一般采用DoubleTensor
torch.set_default_tensor_type(torch.DoubleTensor)# 设置tensor的默认类型
创建数据类型
1、未初始化的数据,数据不规则,需要赋值
Torch.empty(1) # 未初始化,长度为1
Torch.FloatTensor(d1,d2,d3)
Torch.IntTensor(d1,d2,d3)
2、随机初始化
torch.rand(3,3) #[0,1] 随机初始化 3*3矩阵
torch.randint(1,10,[3,3]) # 第一个参数最小值,第二个参数最大值,[3,3] 表示矩阵大小
3、正态初始化
torch.randn(3,3)
torch.normal(mean=torch.full([10],0),std=torch.arange(1,0,-0.1))
#维度为1,长度为10
4、赋值为同一个数值
torch.full([2,3],7) # 2*3的矩阵,每个元素为7
torch.full([],7) # 标量
torch.full([2],7) # 向量,【7,7】
5、等差数列
6、等切
第三个参数为切割份量
7、矩阵全1or0or对角矩阵
torch.ones(3,3)
torch.zeros(3,3)
torch.eye(3,4) # (i,i) 位置为1,其他为0
torch.eye(3) # 3*3矩阵
8、随机打散种子
a.b的索引多应相同
索引与切片
索引
1、“:” 单独出现表示取全部
2、“1:” 表示从1到末尾(包括1)
3、“:1” 表示从0到1(不包括1)
4、注意:[0 1 2] 反向标记为[-3 -2 -1]
5、“0:28:2” [0,28) 第三个冒号后面是步长
6、a.index_select(0,[0,2]) 选择第0个参数
选择第2个维度(0开始),范围设定8
7、“…"表示多个通道维度, 例如:a[…].shape 任意多的维度
8、mask 设置掩码,这里是x矩阵中大于0.5的,不过后面变为1*x的矩阵
Tensor维度变换
1、合并信息 a.view() 可能会造成数据污染
2、unsqueeze 加入一个维度
插入维度 a.unsqueeze(0) 表示在0的前面加一个维度(负数在后面加一个维度)
范围:
3、squeeze 删除或者减掉一个维度
不传参,挤压掉所有1;传入索引,挤压掉该索引位置,如果是1挤压,否则挤压。
4、extend 扩展
使用条件:前后维度必须一致;1->N才能扩张,不可3->M;-1表示不进行扩张
5、repeat 拷贝
该维度上拷贝的次数,维度一样
6、矩阵转置
.t 转置二维!
7、permute
交换维度:permute(x,y,z,r) 标记原来的维度编号
自动扩展Broadcast
按照需要自动,高维度叉1,复制date;右边为小维度,从小维度开始加
拼接与拆分
1、cat 在一个维度上进行拼接
torch.cat([list],dim) 第一个参数将那几个进行拼接;第二个参数,在哪一个维度上进行拼接;确保其他维度上的值一定
2、stack([a1,a2],dim=2)创建一个新的维度
3、split 根据长度拆分
c.split([3,2],dim=0) 对于第零个维度进行拆分,分为3,2两个组
4、chunk
按数量拆分,c.chunk(2,dim=0) 对于第零个维度,拆分为两个组
数学运算
1、加、减
2、“.matmul” 矩阵方式相乘,只取后面的两个维度进行相乘,前面的两维不变
3、“.* ” 对应位置相乘
4、平方,a.pow(2) 平方;两个**也表示平方
5、开方,sqrt(2)
6、立方,
7、近似值
floor() 往下取整;ceil()往上取整
clamp(10) 小于10的值都变为10;clamp(0,10)限制在0-10之间
统计属性
1、norm 范数
2、mean sum/所有元素
3、argmax() 最大值的索引,argmin()最小值的索引(打平后一维的索引)
4、dim
5、keepdim=true保持原来的维度
6、eq 判断每一个元素,equal 判断整个矩阵
防止过拟合的方法——regularization