本地模块:
getdata.py:
定义一个转换关系,用于将图像数据转换成tensor形式:torchvision.transforms模块
定义项目数据集类,继承自torch.utils.data.Dataset类,实现该抽象类的__getitem__方法(根据索引取图片)和__len__方法
network.py:
定义项目的神经网络模型,继承自torch.nn.Module类,实现__init__(定义神经网络结构,包括各个层的初始化和参数设置)和forward(定义神经网络的前向传播过程)两个方法。
训练:
train.py:
导入两个本地模块
导入pytorch官方模块
定义全局变量
写训练方法:
1.实例化数据集
2.引入torch.utils.data.DataLoader包装该数据集
3.实例化一个网络
4.网络送入GPU,即model=model.cuda()
5.网络设定为训练模式:model.train()
6.实例化优化器,即调整网络参数的方式,引自torch.optim.Adam
7.定义loss计算方法:criterion = torch.nn.CrossEntropyLoss()
8.进入训练轮次
数据集地址:
代码地址:
https://github.com/xbliuhnu/DogsVsCats
注:加粗字块均可以去官网阅读源码