第3章 多层全连接神经网络
3.1 热身:PyTorch基础
3.1.1 Tensor(张量)
Tensor,张量,是一个多维矩阵,零维矩阵是一个点,一维是向量,二维是一般的矩阵,多维就相当于一个多维数组。PyTorch的Tensor可以和numpy的ndarray相互转换,PyTorch可以在GPU上运行。
不同数据类型的Tensor,torch.FloatTensor, torch.DoubleTensor, torch.ShortTensor, torch.IntTensor, torch.LongTensor。torch.Tensor默认的是torch.FloatTensor数据类型,也可以定义所需要的类型。
b.numpy()能将tensor转换为numpy数据类型,torch.from_numpy()能将numpy转换为tensor。
torch.cuda.is_available()用于判断是否支持GPU,a.cuda()能够将tensor a放到GPUh 。
3.1.2 Variable
variable提供自动求导功能,variable会被放入一个计算图中,然后进行前身传播,反向传播,自动求导。
variable是在torch.autograd.Variable中,Variable(a)可以将一个tensor a变成Variable。variable比较重要的属性有data,grad和grad_fn。通过data可以取出variable里的tensor的数值,grad_fn表示得到这个Variable的操作,grad是这个Variable是反传播梯度。构建variable时,传入参数requires_grad=True,表示对这个变量求梯度,默认是False。y.backward()就是自动求导,自动求导不需要你再去明确地写明哪个函数对哪个函数求导,直接通过这行代码就对所有的需要梯度的变量进行求导,得到它们的梯度,然后通过x.grad可以得到x的梯度。
3.1.3Dataset
PyTorch提供了很多工具使得数据的读取和预处理变得很容易。
torch.utils.data.Dataset是代表这一数据的抽象类,可以定制自己的数据类继承和重写这个抽象类,只需要定义__len__和__getitem__这两个函数。
重写Dataset方法这种方式,可以通过迭代的方式取得每一个数据,但是这样很难取batch,shuffle或者多线程去读取数据,PyTorch中定义了其他方法,torch.utils.data.DataLoader来定义一个新的迭代器。
另外,torchvision包中还有一个关于计算机视觉的数据读取类:ImageFolder,主要功能是处理图片,要求图像按root/dog/xxx.png的形式存放。
dataset = ImageFolder(root='root_path', transform=None, loader=default_loader)
其中的root需要是根目录,在这个目录下有几个文件夹,每个文件夹表示一个类别:transform和target_transform是图片增强;loader是图片的读取办法,然后通过loader将图片转换成我们需要的图片类型进入神经网络。
3.1.4nn.Module(模组)
在PyTorch里编写神经网络,所有的层结构和损失函数都来自于torch.nn,所有的模型构建都是从这个基类nn.modeule继承的。定义完模型之后,需要通过nn这个包来定义损失函数,常见的损失函数已经定义在nn中了,比如均方误差、多分类的交叉熵,以及二分类的交叉熵等等。criterion = nn.CrossEntropyLoss() loss = criterion(output,target),这样就可以算出Loss了,也可以根据自己需求定制loss。
3.1.5torch.optim(优化)
通常需要通过修改参数使得损失函数最小化,优化算法就是一种调整模型参数更新的策略。
优化算法分为两大类:一阶优化算法,二阶优化算法 。
torch.optim是一个实现各种优化算法的包,大多数常见的算法都能够直接通过这个包来调用,比如随机梯度下降,以及添加动量的随机梯度下降,自适应学习率等。
在调用的时候将需要优化的参数传入,这些参数都必须是Variable,然后传入一些基本的设定,比如学习率和动量等。
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
3.1.6模型的保存和加载
在PyTorch里使用torch.save来保存模型的结构和参数,有两种保存方式,
(1)保存整个模型的结构信息和参数信息,保存的对象是模型model;
(2)保存模型的参数,保存的对象是模型的状态model.state_dict()。
save的第一个参数是保存对象,第二个参数是保存路径及名称;一般默认用.pth
师兄推荐看另一本更新的书
https://github.com/ShusenTang/Dive-into-DL-PyTorch