摘要前面我们介绍了正向传播, 反向传播, 梯度下降法. 也介绍了Pytorch中的损失函数和优化器, 数据加载器, 数据预处理, 和交叉熵. 这一篇, 我们使用之前学习到的所有知识, 建立一个全连接的神经网络, 来完成手写字符的识别.
简介
前面我们介绍了正向传播, 反向传播, 梯度下降法. 也介绍了Pytorch中的损失函数和优化器, 数据加载器, 数据预处理, 和交叉熵. 这一篇, 我们使用之前学习到的所有知识, 建立一个全连接的神经网络, 来完成手写字符的识别.
之前我曾经写过一个版本的使用Pytorch实现手写字符的识别, 当时主要目的是为了实现如何动态修改网络结构, 所以总的结构不是很完整, 这一篇会写的比较完整. 上一个版本的链接: Pytorch模型实例-MNIST dataset
全连接网络完成手写数字识别
准备工作
在准备工作阶段, 我们需要导入需要的库, 并且判断实验环境是否支持GPU, 还是只能使用CPU.
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device
数据加载与数据预处理
我们使用MNIST数据集, 该数据集可以使用torchvision.datasets.MNIST获得. 这一阶段的任务如下所示:
创建dataset
加载MNIST数据
进行数据预处理, 转换为tensor
创建dataloader
将dataset传入dataloader, 设置batchsize
首先我们创建dataset, 同时设置数据预处理.
# 将数据集合下载到指定目录下,这里的transform表示,数据加载时所需要做的预处理操作
# 加载训练集合(Train)
train_dataset = torchvision.datasets.MNIST(root='./data',
train=True,
transform=torchvision.transforms.ToTensor(),
download=True)
# 加载测试集合(Test)
test_dataset = torchvision.datasets.MNIST(root='./data',
train=False,
transform=transforms.ToTensor())
print(train_dataset) # 训练集
"""
Dataset MNIST
Num