《深度学习之Pytorch》学习笔记

第3章 多层全连接神经网络

3.1 热身:PyTorch基础

3.1.1 Tensor(张量)

Tensor,张量,是一个多维矩阵,零维矩阵是一个点,一维是向量,二维是一般的矩阵,多维就相当于一个多维数组。PyTorch的Tensor可以和numpy的ndarray相互转换,PyTorch可以在GPU上运行。

不同数据类型的Tensor,torch.FloatTensor, torch.DoubleTensor, torch.ShortTensor, torch.IntTensor, torch.LongTensor。torch.Tensor默认的是torch.FloatTensor数据类型,也可以定义所需要的类型。

b.numpy()能将tensor转换为numpy数据类型,torch.from_numpy()能将numpy转换为tensor。

torch.cuda.is_available()用于判断是否支持GPU,a.cuda()能够将tensor a放到GPUh 。

 

3.1.2 Variable

variable提供自动求导功能,variable会被放入一个计算图中,然后进行前身传播,反向传播,自动求导。

variable是在torch.autograd.Variable中,Variable(a)可以将一个tensor a变成Variable。variable比较重要的属性有data,grad和grad_fn。通过data可以取出variable里的tensor的数值,grad_fn表示得到这个Variable的操作,grad是这个Variable是反传播梯度。构建variable时,传入参数requires_grad=True,表示对这个变量求梯度,默认是False。y.backward()就是自动求导,自动求导不需要你再去明确地写明哪个函数对哪个函数求导,直接通过这行代码就对所有的需要梯度的变量进行求导,得到它们的梯度,然后通过x.grad可以得到x的梯度。

3.1.3Dataset

PyTorch提供了很多工具使得数据的读取和预处理变得很容易。

torch.utils.data.Dataset是代表这一数据的抽象类,可以定制自己的数据类继承和重写这个抽象类,只需要定义__len__和__getitem__这两个函数。

重写Dataset方法这种方式,可以通过迭代的方式取得每一个数据,但是这样很难取batch,shuffle或者多线程去读取数据,PyTorch中定义了其他方法,torch.utils.data.DataLoader来定义一个新的迭代器。

另外,torchvision包中还有一个关于计算机视觉的数据读取类:ImageFolder,主要功能是处理图片,要求图像按root/dog/xxx.png的形式存放。

dataset = ImageFolder(root='root_path', transform=None, loader=default_loader)

其中的root需要是根目录,在这个目录下有几个文件夹,每个文件夹表示一个类别:transform和target_transform是图片增强;loader是图片的读取办法,然后通过loader将图片转换成我们需要的图片类型进入神经网络。

3.1.4nn.Module(模组)

在PyTorch里编写神经网络,所有的层结构和损失函数都来自于torch.nn,所有的模型构建都是从这个基类nn.modeule继承的。定义完模型之后,需要通过nn这个包来定义损失函数,常见的损失函数已经定义在nn中了,比如均方误差、多分类的交叉熵,以及二分类的交叉熵等等。criterion = nn.CrossEntropyLoss() loss = criterion(output,target),这样就可以算出Loss了,也可以根据自己需求定制loss。

3.1.5torch.optim(优化)

通常需要通过修改参数使得损失函数最小化,优化算法就是一种调整模型参数更新的策略。

优化算法分为两大类:一阶优化算法,二阶优化算法 。

torch.optim是一个实现各种优化算法的包,大多数常见的算法都能够直接通过这个包来调用,比如随机梯度下降,以及添加动量的随机梯度下降,自适应学习率等。

在调用的时候将需要优化的参数传入,这些参数都必须是Variable,然后传入一些基本的设定,比如学习率和动量等。

optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

3.1.6模型的保存和加载

在PyTorch里使用torch.save来保存模型的结构和参数,有两种保存方式,

(1)保存整个模型的结构信息和参数信息,保存的对象是模型model;

(2)保存模型的参数,保存的对象是模型的状态model.state_dict()。

save的第一个参数是保存对象,第二个参数是保存路径及名称;一般默认用.pth

 

师兄推荐看另一本更新的书

https://github.com/ShusenTang/Dive-into-DL-PyTorch

 

 

 

 

 

 

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值