一、Chinese MNIST数据集
本数据集来自Kaggle网站 Chinese MNIST | Kaggle
主要包括15000张64*64的手写中文数字图片,和一份内容文件。
二、神经网络结构
三层全连通网络:4096*300*80*15
三、传播过程
BP算法的计算过程可参考之前的文章,有详细说明,不再赘述。
四、本项目重点:数据集的载入
这里主要采用通过文件名获得标签的方法。具体实现过程可参考B站视频教程【绝对干货】pytorch加载自己的数据集,数据集载入-视频合集
五、程序(pytorch)
# 1 加载必要的库
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
import torchvision
import os
from PIL import Image
from torch.utils.data import DataLoader,Dataset
import matplotlib.pyplot as plt
from sklearn import preprocessing
# 2 定义超参数
batch_size = 128 #训练每批处理的数据
num_epochs = 10 #训练数据集的轮次
# 3 下载、加载数据
path_dir = "F:\\JetBrains\\Pycha