手写字母识别 深度学习 pytorch 线性层 EMNIST数据集 CUDA GPU训练 可使用图片测试 可视化 源代码 详细注释

基于广为流传的手写数字识别的训练代码改进而来

效果

在这里插入图片描述
文件目录
在这里插入图片描述

main.py

#神经网络的包
import torch
from torch import nn  # 神经网络相关工作
from torch.nn import functional as F  # 常用函数
from torch import optim  # 优化工具包
import torchvision #计算机视觉
from torchvision import transforms
#其他包
from matplotlib import pyplot as plt
from matplotlib.colors import Normalize
import random
#自己的函数
from utils import plot_image, plot_curve, one_hot

#超参数
batch_size = 512  #一次处理多张图片
LR=0.001

#1)加载数据集EMNIST
#原来的图片进行过处理:水平翻转图像,然后逆时针旋转90度
# 因此,我们处理时,要先顺时针旋转90度,再次水平翻转
train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.EMNIST(
        './data', 
        train=True, 
        download=True,
        # 下载的数据为numpy格式转换为tensor格式,正则化使原本[0,1]的数据在0附近均匀分布
        transform=torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize((0.1307,),(0.3081,)),
            transforms.RandomRotation(degrees=(90, 90)), #旋转90度
            transforms.RandomVerticalFlip(p=1),  #水平翻转(概率=1)
        ]),
        split="letters"#只使用字母集进行测试
    ),
    batch_size=batch_size,#批量大小
    shuffle=True#随机打散
)  # shuffle把数据随机打散
 
test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.EMNIST(
        './data', 
        train=False, 
        download=True,
        transform=torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize((0.1307,),(0.3081,)),
                transforms.RandomRotation(degrees=(90, 90)),   
                transforms.RandomVerticalFlip(p=1)
        ]),
        split="letters"
        ),
    batch_size=batch_size,
    shuffle=True
) 

#2)创建网络
class Net(nn.Module):
    #初始化
    def __init__(self):
        super(Net,self).__init__() 
        #线性层,每一层为xw+b
        self.fc1 = nn.Linear(28*28,
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值