本人大二,在学习pytorch,写这个纯粹是为了怕自己忘记,以及整理、理顺一下学到的东西。
一、PyTorch的安装
我使用的是Pycharm,直接在Python Packages里搜索Pytorch,选择pytorch-ignite进行安装。这个包应该是高配版的pytorch,里面多了一些大佬们写好的功能。我最初其实选择的是pytorch那个包进行安装,但是失败了,报错我也看不懂,所以就试了试ignite,结果成功了。
import torch
上述语句可以进行包的引用。
二、Tensor张量
类似于numpy,Pytorch有自己的一套管理矩阵的基本数据结构,名为tensor。使用tensor的一大好处是可以利用gpu加快计算速度(虽然说我也不懂gpu为啥能进行计算,先记住吧,嗯啊。。)
此处省略一些tensor初始化的操作。。。
需要一提的是tensor类有几个属性:
以使用torch.tensor()进行初始化时为例,可以设置参数(data,dtype,device,requires_grad)这几个参数的含义为:
1.data:tensor所存储的数据。直接用形如[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]这样的数组传入初始化就可以。
2.dtype:tensor存储数据的类型。具体有float16、float32、float64,int8、int16、int32、int64,double。但是好像不同数据类型间的tensor不能进行运算,至少我用torch进行matrix multiply会报错。
3.device:tensor以什么硬件进行运算。不加说明自动创建cpu类型的tensor,最常应用的就是cpu和gpu,在gpu的叫做cuda,如果想要看看自己的gpu可不可以用,可以用torch.cuda.is_available()进行进行检测。这个参数其实是一个Union,还可以填写cpu, cuda, xpu, mkldnn, opengl, opencl, ideep, hip, msnpu, mlc, xla, vulkan, meta, hpu等,有一说一,别的我都不认识。
4.requires_grad:这是一个bool值,若为True则torchtorch便创建另一个类名叫Function,来对tensor的计算进行跟踪,你可以使用print(tensor.grad_fn)来进行查看,从而可以进行反向传播求的梯度(用以对神经网络的参数进行优化)
5.grad:(其实torch.tensor()这个函数里没有这个参数,这是tensor的属性)这个数据类型是一个和data形状一样的tensor,用来存放计算出来的梯度。
三、计算图
机器视觉领域使用计算图进行求导运算,这一过程其实和高数里求导数,偏导本质是一样的。值得一提的是,在计算图中每一个数据用一个节点表示,最初进行计算的几个数据叫做leaf_node,其实这个数据也是tensor。只能对是leaf_node的tensor访问其.grad属性,否则会报错。
四、tensor.backward()
这个方法函数是用来对tensor经历的运算过程进行反向传播求导的,也就是说之前将requires_grad标记为True,只起到了让torch用Function类对tensor的计算过程进行记录的作用,每一步的梯度并没有被计算。只有对运算出来的tensor调用.backward()方法函数才会进行梯度的计算,并把梯度存在leaf_node(这是一个tensor,是进行运算的最原始数据).grad里。
backward()方法函数有两个用的到的参数为(gradient,retain_graph):
1:gradient:这个数据类型是tensor。先说点题外的,一个tensor运算的结果可能是一个标量scalar(哎呀,其实就是一个数)比方说进行了tensor.mean()运算(求取平均值);也可能是一个向量,这个就比较常见了,比方说矩阵点乘。如果对一个标量调用.backward()那不会出什么问题,但当对一个向量进行.backward()时会报错:grad can be implicitly created only for scalar outputs。所以需要一个gradient,而最后计算的梯度结果(此时的计算结果是一个向量,即矩阵)将会点乘这个gradient(!!这里的点乘和线性代数里面的矩阵乘法不一样!!可以看这篇博客Python 之 numpy 和 tensorflow 中的各种乘法(点乘和矩阵乘) - 刘[小]倩 - 博客园)torch这么做可能是为了利用gradient过滤掉计算结果其他维度的数据吧。
2:retain_graph:这是一个bool值。为了节省内存每次进行梯度计算后 ,计算图就会被删除,将此参数标记为True则不会。