前言
神经网络是一种很特别的解决问题的方法。本书将用最简单易懂的方式与读者一起从最简单开始,一步一步深入了解神经网络的基础算法。本书将尽量避开让人望而生畏的名词和数学概念,通过构造可以运行的Java程序来实践相关算法。
关注微信号“逻辑编程"来获取本书的更多信息。
这一章节我们将会解决一个真正的问题:手写字体识别。我们将识别像下面图中这样的手写数字。
在开始之前,我们先要准备好相应的测试数据。我们不能像前边那样简单的产生手写字体,毕竟我们自己还不知道如何写出一个产生手写字体的算法。训练要达到一定的精度需要较多的训练数据。还好,前人栽树后人乘凉,先驱们已经收集了宝贵的训练材料。MNIST就是一个广泛使用的数据集。不但可以拿来用,我们还可以从网站上看到别人的识别准确率。这样我们就有了很好的参照。MNIST包含一套训练数据和一套测试数据,分别来自不同的人群的手写。
MNIST网站: http://yann.lecun.com/exdb/mnist/
这个数据集是写在特定的二进制文件中的,并非普通图片格式。每个图片数据由28*28个像素组成。每个像素1个字节表示颜色灰度级。MNIST网站上有具体的介绍。
我们写一个类来完成数据集的读取工作,并提供接口返回指定的训练或者测试数据。具体代码不做分析,仅将代码附在下面,供读者使用。代码执行前要先下载数据文件并保留GZIP格式。代码执行后将随机抽取20个生成PNG图片供读者自己查看和验证数据内容。
下面我们写个测试类来识别手写字体。我们使用MNIST库的60000训练数据来反复训练我们的神经网络。每轮训练后使用MNIST库的10000个测试数据来测试识别率。
下面是代码:
package com.luoxq.ann;
import java.util.Arrays;
import java.util.Random;
public class MnistTest {
public static void main(String... args) {
int[] shape = {28 * 28, 10};
NeuralNetwork nn = new NeuralNetwork(shape);
Mnist mnist = new Mnist();
mnist.load();
mnist.shuffle();
System.out.println("Shape: " + Arrays.toString(shape));
System.out.println("Initial correct rate: " + test(nn, mnist));
int epochs = 1000;
double rate = 0.5;
System.out.println("Learning rate: " + rate);
System.out.println("Epoch,Time,Correctness\n----------------------");
long time = System.currentTimeMillis();
Mnist.Data[] data = mnist.getTrainingSlice(0, 60000);
for (int epoch = 1; epoch <= epochs; epoch++) {
for (int sample = 0; sample < data.length; sample++) {
nn.train(data[sample].input, data[sample].output, rate);
}