这里用到torch模块,可以在线下载数据集。识别十个数字。
前几次迭代期结果已经达到95%以上的准确率,最终结果在96%左右稳定,过拟合和欠拟合风险较小。CPU训练速度有些慢。
import torch.nn as nn
import torch
import torchvision
from torch.utils.data import DataLoader
import time
print("start")
EPOCH = 50 # 总的训练次数
BATCH_SIZE = 20 # 批次的大小
LR = 0.03 # 学习率(交叉熵损失函数不需要太大的学习率)
DOWNLOAD_MNIST = False # 运行代码的时候是否下载数据集(需要下载数据集改为True)
cuda_available = torch.cuda.is_available() # 获取GPU是否可用,可用的话就用GPU进行训练和测试
# #对于这样的网络,可能cpu更快一些
cuda_available = False # 即使gpu可用,也可以执行这一句,测试训练在cpu上的训练速度
# 设置一个转换的集合,先把数据转换到tensor,再归一化为均值.5,标准差.5的正态分布
trans = torchvision.transforms.Compose(
[
torchvision.transforms.ToTensor(), # ToTensor方法把[0,255]变成[0,1]
torchvision.transforms.Normalize([0.5], [0.5])