基于广为流传的手写数字识别的训练代码改进而来
效果

文件目录

main.py
#神经网络的包
import torch
from torch import nn # 神经网络相关工作
from torch.nn import functional as F # 常用函数
from torch import optim # 优化工具包
import torchvision #计算机视觉
from torchvision import transforms
#其他包
from matplotlib import pyplot as plt
from matplotlib.colors import Normalize
import random
#自己的函数
from utils import plot_image, plot_curve, one_hot
#超参数
batch_size = 512 #一次处理多张图片
LR=0.001
#1)加载数据集EMNIST
#原来的图片进行过处理:水平翻转图像,然后逆时针旋转90度
# 因此,我们处理时,要先顺时针旋转90度,再次水平翻转
train_loader = torch.utils.data.DataLoader(
torchvision.datasets.EMNIST(
'./data',
train=True,
download=True,
# 下载的数据为numpy格式转换为tensor格式,正则化使原本[0,1]的数据在0附近均匀分布
transform=torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize((0.1307,),(0.3081,)),
transforms.RandomRotation(degrees=(90, 90)), #旋转90度
transforms.RandomVerticalFlip(p=1), #水平翻转(概率=1)
]),
split="letters"#只使用字母集进行测试
),
batch_size=batch_size,#批量大小
shuffle=True#随机打散
) # shuffle把数据随机打散
test_loader = torch.utils.data.DataLoader(
torchvision.datasets.EMNIST(
'./data',
train=False,
download=True,
transform=torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize((0.1307,),(0.3081,)),
transforms.RandomRotation(degrees=(90, 90)),
transforms.RandomVerticalFlip(p=1)
]),
split="letters"
),
batch_size=batch_size,
shuffle=True
)
#2)创建网络
class Net(nn.Module):
#初始化
def __init__(self):
super(Net,self).__init__()
#线性层,每一层为xw+b
self.fc1 = nn.Linear(28*28,

最低0.47元/天 解锁文章

4411

被折叠的 条评论
为什么被折叠?



