在pytorch中我么如何去学着搭建一个最基本的架构呢?
一个架构的搭建分为一下几步:
1:导入常用的包:torch,torch.nn,torch.functional等
2:将要处理的数据导入,这里不得不说,pytorch现阶段支持的数据集比较少,如果你要使用的数据集不在其支持的数据集列表里,那你就要自己编写程序进行导入了,这个会在后面的章节里详说
3:网络的搭建,写一个网络类,然后内部包含两个方法:
1》__init__()函数,这个主要是完成搭建材料的导入工作,如:self.relu = torch.nn.Relu()
2》forward()函数,这个主要是按照顺序搭建起整个网络来,最后得到结果
4:选择要使用的损失函数类型和优化器类型
5:框架最便利也是最吸引人的自动求导
1》zero_grad():将所有的参数的导数置零(为什么要置零?本人亲身试验过,如果不置零,其参数的导数会不断的累加,我们知道,我们在进行梯度下降的过程中当前的下降只用到当前的导数,下一个地点的下降是下个地点导数的事情,如果将其累加,就会出错)
2》losses.backward()这里的losses是我们使用的损失函数,对其进行反向求导
3》optim.step()这里的optim是我们选择的优化器,step是进行一步优化,也就是将我们上面back