1. 准备数据集dataset,及数据加载dataloader
我用的是CIFAR-10彩色图像数据集,CIFAR-10 是一个用于识别普适物体的小型数据集。一共包含10 个类别的RGB 彩色图片:飞机( airplane )、汽车( automobile )、鸟类( bird )、猫( cat )、鹿( deer )、狗( dog )、蛙类( frog )、马( horse )、船( ship )和卡车( truck )。CIFAR-10数据集中每个图片的尺寸为32 × 32 ,每个类别有6000个图像,数据集中一共有50000 张训练图片和10000 张测试图片。
2.搭建网络模型
CIFAR 10 model结构
3.创建损失函数、优化器
4.设置训练参数(epoch...)
5.网络进入训练状态(调用model.train())
(1)从train_dataloader中加载数据
(2) 计算损失函数
(3) 反向传播,优化器优化
(4) print, tensorboard 展示输出
6. 每个epoch训练完成后,网络进入测试状态
(1) 在with torch.no_grad下进行(只测试,无梯度优化)
(2) 从test_dataloader中加载数据
(3) 计算指标(loss,acc),展示模型效果