import torch
# 导入torchvision的transform模块,用来处理数据集
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
# 创建一个transforms实例,用于对数据集进行处理
transform = transforms.Compose([
# 将图片转换为tensor
transforms.ToTensor(),
# 归一化
transforms.Normalize((0.1307,), (0.3081,))
])
# 创建训练集
"""
datasets.MNIST是Pytorch的内置函数torchvision.datasets.MNIST,通过这个可以导入数据集。
root设置数据集路径,如果该路径中有需要的数据,则使用,如果没有且download为True,则下载数据集到目录中。
train=True 表明这是训练集
transform=transform 表明使用上面定义的transform来处理数据集。
"""
# 创建训练集
train_dataset = datasets.MNIST(root='DataSet\\',
train=True,
download=True,
transform=transform)
# 创建测试集
test_dataset = datasets.MNIST(root='DataSet\\',
机器学习/深度学习--手写数字识别(MNIST数据集)
于 2023-03-07 16:04:29 首次发布
本文通过深度学习方法对MNIST数据集进行手写数字识别,经过10轮训练,模型达到97.8%的识别准确率。实操中,将测试集替换为自定义图片,展示预测效果。
摘要由CSDN通过智能技术生成