这篇博客接上一篇博客Make Your Own Neural Network之代码详解上。本文也是出自Make your own Neural NetWork这本书。上一篇博客讲了神经网络类的功能模块,本文主要介绍如何对这个神经网络类进行训练和测试。
声明:
- 代码用的Python编写
- 红字部分为后面要讲的内容的中心句
MNIST数据集——手写汉字数据
识别人的手写数字是测试人造智能的理想挑战,因为这个问题足够困难和模糊。它不像许多递增或递减数列一样规律,没有很清楚的定义。手写数字识别也是图像识别的一个分支——使计算机正确分类图像包含什么的问题。这个问题经过了数十年的研究,只到最近才有良好的进展,而且像神经网络这样的方法一直是这些跨越式发展的关键部分。我们人类有时候也会对图像包含的内容产生分歧。因为我们会写出不同的手写数字图片,特别是如果人物是着急或不小心写的。为了让您了解图像识别问题的难度,请看下图,是4或9?
人工智能研究人员使用手写数字的图像集合作为一个受欢迎的设备来测试他们最新的想法和算法。事实上这个数据集常被用来检查我们最新的疯狂想法的图像识别方法,并与其他作品进行对比。也就是,用相同的数据集对不同的想法和算法进行测试。
数据来源
该数据集称为手写数字的MNIST数据库,可从传奇的神经网络专家Yann LeCun的网站下载。该页面还列出了新旧想法在学习和正确率方面的表现。我们会多次回到这个名单,看看怎么样我们自己的网络与专业网络的差别!MNIST数据库的格式不是最容易使用的,所以其他人又帮助创建了一个更简单格式的数据文件,如这个。这些文件称为CSV文件,这意味着每个值都是以逗号分隔的纯文本(逗号分隔值)。您可以在任何文本编辑器和大多数电子表格中轻松查看它们或数据分析软件将与CSV文件一起工作。它们几乎是一个通用标准。该网站提供两个CSV文件,这也将是本文将要使用的数据集:
- 训练集http://www.pjreddie.com/media/files/mnist_train.csv
- 测试集http://www.pjreddie.com/media/files/mnist_test.csv
顾名思义,训练集是用于训练神经网络的60000个标记示例的集合。标签表示输入具有所需的输出,也就是目标值。使用较小的10,000个测试集来查看我们的想法或算法是否有效。这个也是包含正确的标签,所以我们可以检查一下我们自己的神经网络是否得到了正确结果。相互独立的训练集合测试集是为了,确保我们测试时用的是以前没有看到的数据。否则,我们可以骗自己来简单地记住训练数据,以获得完美的、尽管是欺骗性的得分。将训练与测试的数据分离的思想在机器学习中是常见的。下面显示了加载这些MNIST测试集的一部分到一个文本编辑器。
哇!看起来有些事情出了问题!像80年代电脑被黑客入侵的电影之一。其实一切OK。显然,这些行由数字组成,用逗号分隔。由于一条数据相当长,所以它们换行了好几次。这个文本编辑器很贴心地在左边栏显示了真实的行号。现在,我们可以看到四条完整的手写数字样本,和第五条的部分数据。这些记录或文本行的内容很容易理解:
- 第一个值是标签,即手写应该表示的实际数字,例如“7”或“9”。这就是神经网络试图学习正确的答案。
- 后续值(逗号分隔)是手写数字的像素值。像素阵列的大小是28×28,所以在标签后面有784个值。
每一条记录的第一个值所示的数字“5”,代表该行的其余文本是某人的手写编号5的像素值。第二个记录表示手写的“0”,第三个表示“ 4“,第四纪录是”1“和第五代表“9”,分别代表后面的是图片0、4、1、9的像素列表。您可以从MNIST数据文件中选择任何一行,第一个数字将告诉您以下图像数据的标签。但是,很难看出784个值的长列表如何构成了某人的手写编号5的图片。我们应该将这些数字作为图像进行展示,以确认它们是手写编号的像素值。在我们深入了解之前,我们应该下载较小的MNIST数据集。 毕竟MNIST数据数据文件相当大,使用较小的子集比较方便。因为它意味着我们可以尝试,试用和开发我们的代码,而不会因为操作一个大的数据集减慢了我们的电脑。一旦我们解决了一个算法和代码,我们就可以开心的使用完整的数据集了。
以下是MNIST数据集的较小子集的链接,也是CSV格式:
# Python读取文件
data_file = open(r"E:\keras\sample\famous_data\mnist_train_100.csv", 'r') #一个可读的文件对象
# 因为文件不大,所以一次读了整个文件。理论上应该一行一行的读。
data_list = data_file .readlines() # 返回一个数据列表,data_list[i]表示第i样本
data_file.close()
可以看到第一条记录data_list [0]的内容:第一个数字是“5”,它是标签,784个数字中的其余数字是构成图像的像素的颜色值。如果你仔细观察,你可以告诉这些颜色值的范围在0到255之间。你可能想查看其他记录,看看那是否也是真的。您会发现颜色值确实在0到255之间。我们可以使用imshow()
函数绘制一个矩形数组的数组,但是我们需要将逗号分隔的数字列表转换成合适的数组。
以下是执行此操作的步骤:
- 将逗号分隔的长文本字符串分隔为各个值,使用逗号作为拆分的位置。
- 忽略作为标签的第一个值,并取剩余的28 * 28 = 784个列表,并将其转换为28行×28列的形状。
- 绘制阵列!
同样,最简单的方法就是显示相当简单的Python代码,并且通过代码来更详细的解释发生了什么。
# 首先,我们不能忘记导入Python扩展库,这将帮助我们使用数组和绘图: