我们将构建一个三层的神经网络来处理手写数字识别问题,之后我们将运用AdaGrad、RMSprop、Momentum、Nesterov Momentum和Adam优化算法来加速梯度下降的过程,首先我们先来实现一个简单的神经网络。
文章目录
1. 导入所需的Python库
# coding: utf-8
import numpy as np
import matplotlib.pyplot as plt
from utils import load_mnist
from collections import OrderedDict
2. 加载数据并可视化
先介绍一下在这个实验中所用到的数据库MNIST,MNIST数据集是一个手写体数据集,其中每个手写数字是一张28×28的灰度图片,图片的标记为一个0-9表示的数字。 MNIST数据集一共有60000张图片用来作为训练集,10000张图片来作为测试集。
我们知道一张灰度图片一般是二维的,但是神经网络中的全连接层的输入是一个一维的向量。所以我们需