MNIST是什么
MNIST是一组经过预处理的手写数字图片数据集,它为机器学习的初学者提供了一个练手的机会,可以在真实的数据上用学到的算法来解决问题。由于很多的机器学习教程都以MNIST作为入门项目,因此它也被称作是机器学习领域的“hello world”。
MNIST中每个样本都是一张长28、宽28的灰度图片,其中包含一个0-9的数字。我们需要做的,就是根据训练数据建立一个模型用来识别输入图片中的数字。这是典型的分类问题,每个样本的输入是784维向量:一张图片有28*28=784个像素点,每个点用一个浮点数表示其亮度;输出是10维向量,十个分量分别表示输入图中数字是0~9的可能性,其中可能性最大的,就是算法预测的结果。
安装TensorFlow
本文的代码采用Python和TensorFlow编写,所以需要一个Python开发环境
GPU版本
框架
在开始具体的算法之前,我们先搭建一个通用的框架。框架要完成一些不同算法都需要做的工作,比如加载数据集、定义和训练模型,验证模型准确率等等。这样后面实现具体算法的时候就只需要关注跟算法相关的代码。
机器学习的过程,就是用模型对训练数据进行拟合的过程。这里有两个核心,其一是“模型”。一个机器学习模型应该包括两个部分:从输入到输出的计算过程,也就是框架里的inference()函数;以及计算模型拟合程度的损失函数,也就是loss()函数。本文中的几种算法,还有其他更复杂的机器学习算法,都是一些经过验证具有实用价值的模型。机器学习算法的第二个关键是“拟合”,也就是在给定的模型和训练