Preface
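Neural networks are a rather special way of solving problems. This book starts from the very simplest case and, explaining things as plainly as possible, works step by step with the reader into the basic algorithms of neural networks. It avoids intimidating terminology and mathematical concepts wherever it can, and puts each algorithm into practice by building Java programs that actually run.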
Follow the WeChat account “逻辑编程” (Logic Programming) for more information about this book.
In this chapter we tackle a real problem: handwriting recognition. We will recognize handwritten digits like the ones in the figure below.
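Before we start, we need suitable test data. We can't simply generate handwritten digits the way we generated data earlier, since we don't know how to write an algorithm that produces handwriting. Reaching reasonable accuracy also takes a fair amount of training data. Fortunately, we can rely on the work of those who came before us: pioneers in the field have already collected valuable training material, and MNIST is a widely used dataset. Not only can we use it directly, the website also lists the recognition accuracy others have achieved, which gives us a good benchmark. MNIST contains a training set and a test set, written by different groups of people.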
MNIST website: http://yann.lecun.com/exdb/mnist/
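The dataset is stored in a specific binary format rather than as ordinary image files. Each image consists of 28×28 pixels, and each pixel is a single byte giving its gray level. The MNIST website describes the format in detail.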
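The format is simple enough to check by hand. The sketch below assumes only the layout described on the MNIST website: a 32-bit big-endian magic number 0x00000803, then the image count, the row count and the column count, followed by the raw pixels. It parses the 16-byte header of an image file. The class IdxImageHeader and its methods are illustrative names and are not part of the book's code.

```java
import java.nio.ByteBuffer;

// Hypothetical helper: parses the 16-byte header of an MNIST image file (IDX format).
final class IdxImageHeader {
    final int count, rows, cols;

    private IdxImageHeader(int count, int rows, int cols) {
        this.count = count;
        this.rows = rows;
        this.cols = cols;
    }

    // IDX integers are 32-bit big-endian, which is ByteBuffer's default byte order.
    static IdxImageHeader parse(byte[] bytes) {
        ByteBuffer buf = ByteBuffer.wrap(bytes);
        int magic = buf.getInt();
        // 0x08 in the third byte = unsigned byte data, 0x03 in the fourth = three dimensions.
        if (magic != 0x00000803) {
            throw new IllegalArgumentException("wrong magic: 0x" + Integer.toHexString(magic));
        }
        // Java evaluates constructor arguments left to right: count, rows, cols.
        return new IdxImageHeader(buf.getInt(), buf.getInt(), buf.getInt());
    }

    // Byte offset of image i: 16 header bytes, then rows * cols bytes per image.
    long pixelOffset(int i) {
        return 16L + (long) i * rows * cols;
    }

    public static void main(String[] args) {
        // A header as it would appear in the training image file.
        byte[] header = ByteBuffer.allocate(16)
                .putInt(0x00000803).putInt(60000).putInt(28).putInt(28).array();
        IdxImageHeader h = parse(header);
        System.out.println(h.count + " images of " + h.rows + "x" + h.cols);
        System.out.println("image 1 starts at byte " + h.pixelOffset(1));
    }
}
```

Label files follow the same pattern. They start with the magic number 0x00000801 and a count, followed by one byte per label.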
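We write a class that reads the dataset and provides methods returning a given training or test sample. We won't walk through its code; it is simply listed further below for the reader to use. Before running it, download the data files into the working directory and keep them in GZIP format (do not decompress them). When run on its own, it picks 20 random training images and saves them as PNG files, so you can inspect and verify the data yourself.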
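Every sample goes through two conversions before it reaches the network. Each pixel byte becomes a number between 0 and 1, and each label becomes a 10-element target vector with a 1 at the digit's position. Java's byte type is signed, so a gray level of 200 is stored as -56, and the byte must be masked with 0xff before scaling. The minimal standalone sketch below mirrors dataToInput and labelToOutput in the reader class. MnistEncoding is a hypothetical name.

```java
import java.util.Arrays;

// Hypothetical sketch of the per-sample conversions done by the MNIST reader.
final class MnistEncoding {
    // Mask with 0xff to recover the unsigned 0..255 gray level, then scale to [0, 1].
    static double[] toInput(byte[] pixels) {
        double[] in = new double[pixels.length];
        for (int i = 0; i < pixels.length; i++) {
            in[i] = (pixels[i] & 0xff) / 255.0;
        }
        return in;
    }

    // One-hot target: 1.0 at the digit's index, 0.0 everywhere else.
    static double[] toTarget(int label) {
        double[] out = new double[10];
        out[label] = 1.0;
        return out;
    }

    public static void main(String[] args) {
        byte[] pixels = {0, (byte) 128, (byte) 255};
        System.out.println(Arrays.toString(toInput(pixels)));
        System.out.println(Arrays.toString(toTarget(3)));
    }
}
```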
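Next we write a test class that recognizes handwritten digits. We train our network repeatedly on MNIST's 60,000 training samples. After each epoch (one full pass over the training set), we measure the recognition rate on MNIST's 10,000 test samples.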
Here is the code:
package com.luoxq.ann;

import java.util.Arrays;

public class MnistTest {
    public static void main(String... args) {
        // 784 input pixels -> 10 outputs, one per digit.
        // The three-layer run below used {28 * 28, 50, 10}.
        int[] shape = {28 * 28, 10};
        // NeuralNetwork is the network class built earlier in this book (not listed in this chapter).
        NeuralNetwork nn = new NeuralNetwork(shape);
        Mnist mnist = new Mnist();
        mnist.load();
        mnist.shuffle();
        System.out.println("Shape: " + Arrays.toString(shape));
        System.out.println("Initial correct rate: " + test(nn, mnist));
        int epochs = 1000;  // upper bound; the runs shown below were stopped much earlier
        double rate = 0.5;  // learning rate; the three-layer run below used 1.0
        System.out.println("Learning rate: " + rate);
        System.out.println("Epoch,Time,Correctness\n----------------------");
        long time = System.currentTimeMillis();
        Mnist.Data[] data = mnist.getTrainingSlice(0, 60000);
        for (int epoch = 1; epoch <= epochs; epoch++) {
            // One epoch: one training step for every training sample.
            for (int sample = 0; sample < data.length; sample++) {
                nn.train(data[sample].input, data[sample].output, rate);
            }
            // Seconds elapsed since training began, then the number of correct test answers.
            long seconds = (System.currentTimeMillis() - time) / 1000;
            System.out.println(epoch + ", " + seconds + ", " +
                    test(nn, mnist));
        }
    }

    // Returns how many of the 10,000 test images are classified correctly.
    private static int test(NeuralNetwork nn, Mnist mnist) {
        int correct = 0;
        Mnist.Data[] data = mnist.getTestSlice(0, 10000);
        for (int sample = 0; sample < data.length; sample++) {
            if (max(nn.f(data[sample].input)) == data[sample].label) {
                correct++;
            }
        }
        return correct;
    }

    // Index of the largest output, i.e. the digit the network considers most likely.
    private static int max(double[] d) {
        double max = d[0];
        int idx = 0;
        for (int i = 1; i < d.length; i++) {
            if (max < d[i]) {
                max = d[i];
                idx = i;
            }
        }
        return idx;
    }
}
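Let's first try a single-layer network of 10 neurons. The result is surprisingly good: accuracy is already above 90% by the fourth epoch. A single-layer network does little more than gather simple statistics on each digit's pixel distribution, so reaching such a high recognition rate is quite remarkable. Beyond 90%, however, further training helps little. Accuracy stays between about 90.2% and 90.8%, and the network has saturated. We need a different approach.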
Shape: [784, 10]
Initial correct rate: 1373
Learning rate: 0.5
Epoch,Time,Correctness
----------------------
1, 4, 6429
2, 8, 7663
3, 13, 8963
4, 17, 9029
5, 22, 9016
6, 27, 9062
7, 31, 9063
8, 36, 9066
9, 41, 9072
10, 45, 9057
11, 50, 9084
12, 55, 9072
13, 61, 9062
14, 66, 9050
15, 70, 9077
16, 75, 9052
17, 79, 9068
18, 84, 9055
19, 88, 9060
20, 93, 9064
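The NeuralNetwork class is not listed in this chapter. The sketch below only illustrates what a single layer of 10 neurons computes, assuming sigmoid neurons, a squared-error loss and one gradient step per sample. The book's own class may differ in these details. Each output neuron holds one weight per pixel, so after training its 784 weights act as a template of where that digit's ink tends to fall. That is one way to read the "simple statistics" remark above. SingleLayerSketch is a hypothetical name.

```java
import java.util.Arrays;

// Hypothetical single-layer network: every output neuron is connected to every input.
// Assumes sigmoid activation, squared-error loss and per-sample gradient descent.
final class SingleLayerSketch {
    final double[][] w; // w[j][i]: weight from input i to output neuron j
    final double[] b;   // b[j]: bias of output neuron j

    SingleLayerSketch(int inputs, int outputs) {
        // Zero initialization keeps this example deterministic; real networks start from random weights.
        w = new double[outputs][inputs];
        b = new double[outputs];
    }

    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    // Output j = sigmoid(sum_i w[j][i] * x[i] + b[j]).
    double[] f(double[] x) {
        double[] y = new double[b.length];
        for (int j = 0; j < b.length; j++) {
            double z = b[j];
            for (int i = 0; i < x.length; i++) z += w[j][i] * x[i];
            y[j] = sigmoid(z);
        }
        return y;
    }

    // One gradient step on E = 1/2 * sum_j (y_j - t_j)^2 for a single sample.
    void train(double[] x, double[] t, double rate) {
        double[] y = f(x);
        for (int j = 0; j < y.length; j++) {
            double delta = (y[j] - t[j]) * y[j] * (1 - y[j]); // dE/dz_j
            for (int i = 0; i < x.length; i++) w[j][i] -= rate * delta * x[i];
            b[j] -= rate * delta;
        }
    }

    public static void main(String[] args) {
        SingleLayerSketch nn = new SingleLayerSketch(4, 3);
        double[] x = {1, 0, 1, 0};
        double[] t = {0, 1, 0}; // one-hot target: class 1
        for (int k = 0; k < 100; k++) nn.train(x, t, 0.5);
        System.out.println(Arrays.toString(nn.f(x)));
    }
}
```

With zero initial weights every output starts at exactly 0.5. After 100 steps on the same sample, the target neuron's output rises above 0.5 and the others fall below it. Weights attached to input pixels that are 0 never change, because their gradient contains the factor x[i].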
So let's try a three-layer network. After experimenting with several hidden-layer sizes and learning rates, I found the following combination works well:
Shape: [784, 50, 10]
Initial correct rate: 944
Learning rate: 1.0
Epoch,Time,Correctness
----------------------
1, 24, 7459
2, 59, 9232
3, 99, 9313
4, 131, 9379
5, 153, 9412
6, 176, 9443
7, 200, 9412
8, 226, 9447
9, 248, 9462
10, 269, 9461
11, 290, 9465
12, 314, 9493
13, 343, 9477
14, 368, 9499
15, 392, 9502
16, 420, 9509
17, 447, 9482
18, 472, 9508
19, 496, 9491
20, 518, 9536
21, 545, 9523
22, 569, 9549
23, 593, 9527
24, 618, 9527
25, 643, 9520
26, 667, 9513
27, 689, 9507
28, 712, 9527
29, 734, 9501
30, 758, 9521
31, 781, 9508
32, 804, 9534
33, 827, 9534
34, 850, 9550
35, 875, 9569
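The Correctness column, like the "Initial correct rate" line, is a count of correctly recognized images out of the 10,000 in the test set, not a percentage. The hypothetical helper below parses rows in the printed "epoch, seconds, correct" format, converts counts to percentages and picks the best epoch. The class name and methods are illustrative only.

```java
import java.util.List;

// Hypothetical helper for summarizing the "Epoch,Time,Correctness" lines printed by MnistTest.
final class LogSummary {
    static final int TEST_SIZE = 10000; // number of MNIST test images

    // Returns {epoch, seconds, correct} for a line such as "35, 875, 9569".
    static int[] parse(String line) {
        String[] parts = line.split(",");
        int[] row = new int[3];
        for (int i = 0; i < 3; i++) row[i] = Integer.parseInt(parts[i].trim());
        return row;
    }

    // Converts a correct-answer count into a percentage of the test set.
    static double percent(int correct) {
        return 100.0 * correct / TEST_SIZE;
    }

    // Returns the row with the highest correct count; the earliest row wins on ties.
    static int[] best(List<String> lines) {
        int[] best = null;
        for (String line : lines) {
            int[] row = parse(line);
            if (best == null || row[2] > best[2]) best = row;
        }
        return best;
    }

    public static void main(String[] args) {
        List<String> lines = List.of("33, 827, 9534", "34, 850, 9550", "35, 875, 9569");
        int[] b = best(lines);
        System.out.printf("best epoch %d: %.2f%% after %d s%n", b[0], percent(b[2]), b[1]);
    }
}
```

Applied to the full three-layer log above, the best row is epoch 35 with 9569 correct, or 95.69%. For the single-layer log, the best row is epoch 11 with 9084 correct, or 90.84%.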
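We pass 95% accuracy at epoch 15 and reach 95.69% by epoch 35, so a multilayer network clearly has an advantage over a single-layer one. This accuracy is not yet good enough for a production system, but for a first attempt it is a very good result.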
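The three-layer network is also considerably larger. The sketch below counts trainable parameters for a given shape. It assumes fully connected layers in which each neuron has one weight per neuron in the previous layer plus one bias. That is the usual structure, but it is an assumption here because the NeuralNetwork class is not shown. ParamCount is a hypothetical name.

```java
// Hypothetical helper: counts trainable parameters of a fully connected network.
// Assumes each non-input neuron has one weight per neuron in the previous layer plus one bias.
final class ParamCount {
    static long count(int[] shape) {
        long total = 0;
        for (int l = 1; l < shape.length; l++) {
            // weights from layer l-1 into layer l, plus the biases of layer l
            total += (long) shape[l - 1] * shape[l] + shape[l];
        }
        return total;
    }

    public static void main(String[] args) {
        System.out.println(count(new int[]{28 * 28, 10}));     // 7850
        System.out.println(count(new int[]{28 * 28, 50, 10})); // 39760
    }
}
```

This gives 7,850 parameters for [784, 10] and 39,760 for [784, 50, 10], about five times as many. That ratio roughly matches the logs. The single-layer run took about 4.7 seconds per epoch (93 s for 20 epochs), and the three-layer run about 25 seconds (875 s for 35 epochs). Both timings also include the test pass after each epoch.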
Below is the source code for reading the MNIST files:
package com.luoxq.ann;

import javax.imageio.ImageIO;
import java.awt.image.BufferedImage;
import java.io.DataInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.util.Random;
import java.util.zip.GZIPInputStream;

/**
 * Created by luoxq on 17/4/15.
 */
public class Mnist {
    static class Data {
        public byte[] data;     // raw 28x28 gray levels, row by row
        public int label;       // the digit, 0-9
        public double[] input;  // pixels scaled to [0, 1]
        public double[] output; // one-hot target vector of length 10
    }

    // Loads the data and saves 20 random training images as PNG files named <i>_<label>.png.
    public static void main(String... args) throws Exception {
        Mnist mnist = new Mnist();
        mnist.load();
        System.out.println("Data loaded.");
        Random rand = new Random(System.nanoTime());
        for (int i = 0; i < 20; i++) {
            int idx = rand.nextInt(60000);
            Data d = mnist.getTrainingData(idx);
            BufferedImage img = new BufferedImage(28, 28, BufferedImage.TYPE_INT_RGB);
            for (int x = 0; x < 28; x++) {
                for (int y = 0; y < 28; y++) {
                    img.setRGB(x, y, toRgb(d.data[y * 28 + x]));
                }
            }
            File output = new File(i + "_" + d.label + ".png");
            if (!output.exists()) {
                output.createNewFile();
            }
            ImageIO.write(img, "png", output);
        }
    }

    // Inverts the gray level so the digit appears dark on a white background.
    static int toRgb(byte bb) {
        int b = (255 - (0xff & bb));
        return (b << 16 | b << 8 | b) & 0xffffff;
    }

    Data[] trainingSet;
    Data[] testSet;

    // Shuffles the training set in place by swapping each position with a random one.
    public void shuffle() {
        Random rand = new Random();
        for (int i = 0; i < trainingSet.length; i++) {
            int x = rand.nextInt(trainingSet.length);
            Data d = trainingSet[i];
            trainingSet[i] = trainingSet[x];
            trainingSet[x] = d;
        }
    }

    public Data getTrainingData(int idx) {
        return trainingSet[idx];
    }

    public Data[] getTrainingSlice(int start, int count) {
        Data[] ret = new Data[count];
        System.arraycopy(trainingSet, start, ret, 0, count);
        return ret;
    }

    public Data getTestData(int idx) {
        return testSet[idx];
    }

    public Data[] getTestSlice(int start, int count) {
        Data[] ret = new Data[count];
        System.arraycopy(testSet, start, ret, 0, count);
        return ret;
    }

    // Reads the four GZIP files from the working directory.
    public void load() {
        trainingSet = load("train-images-idx3-ubyte.gz", "train-labels-idx1-ubyte.gz");
        testSet = load("t10k-images-idx3-ubyte.gz", "t10k-labels-idx1-ubyte.gz");
        if (trainingSet.length != 60000 || testSet.length != 10000) {
            throw new RuntimeException("Unexpected training/test data size: " + trainingSet.length + "/" + testSet.length);
        }
    }

    private Data[] load(String imgFile, String labelFile) {
        byte[][] images = loadImages(imgFile);
        byte[] labels = loadLabels(labelFile);
        if (images.length != labels.length) {
            throw new RuntimeException("Image and label counts don't match: " + imgFile + " " + labelFile);
        }
        int len = images.length;
        Data[] data = new Data[len];
        for (int i = 0; i < len; i++) {
            data[i] = new Data();
            data[i].data = images[i];
            data[i].label = 0xff & labels[i];
            data[i].input = dataToInput(images[i]);
            data[i].output = labelToOutput(labels[i]);
        }
        return data;
    }

    // One-hot encoding: 1 at the digit's index, 0 elsewhere.
    private double[] labelToOutput(byte label) {
        double[] o = new double[10];
        o[label] = 1;
        return o;
    }

    // Bytes are signed in Java, so mask with 0xff before scaling to [0, 1].
    private double[] dataToInput(byte[] b) {
        double[] d = new double[b.length];
        for (int i = 0; i < b.length; i++) {
            d[i] = (b[i] & 0xff) / 255.0;
        }
        return d;
    }

    // Image file: magic 0x00000803, count, rows, cols, then rows * cols bytes per image.
    private byte[][] loadImages(String imgFile) {
        try (DataInputStream in = new DataInputStream(new GZIPInputStream(new FileInputStream(imgFile)))) {
            int magic = in.readInt();
            if (magic != 0x00000803) {
                throw new RuntimeException("wrong magic: 0x" + Integer.toHexString(magic));
            }
            int count = in.readInt();
            int rows = in.readInt();
            int cols = in.readInt();
            if (rows != 28 || cols != 28) {
                throw new RuntimeException("Unexpected row and col count: " + rows + "x" + cols);
            }
            byte[][] data = new byte[count][rows * cols];
            for (int i = 0; i < count; i++) {
                in.readFully(data[i]);
            }
            return data;
        } catch (Exception ex) {
            throw new RuntimeException("Failed to read file: " + imgFile, ex);
        }
    }

    // Label file: magic 0x00000801, count, then one byte per label.
    private byte[] loadLabels(String labelFile) {
        try (DataInputStream in = new DataInputStream(new GZIPInputStream(new FileInputStream(labelFile)))) {
            int magic = in.readInt();
            if (magic != 0x00000801) {
                throw new RuntimeException("wrong magic: 0x" + Integer.toHexString(magic));
            }
            int count = in.readInt();
            byte[] data = new byte[count];
            in.readFully(data);
            return data;
        } catch (Exception ex) {
            throw new RuntimeException("Failed to read file: " + labelFile, ex);
        }
    }
}
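The shuffle() method above swaps each position with an index drawn from the whole array. That mixes the data well enough for training, but it does not make all orderings equally likely. It produces n^n equally probable sequences of choices, and for arrays longer than two elements these cannot be spread evenly over the n! possible permutations. The standard Fisher-Yates shuffle draws each swap index only from the positions not yet fixed. It is sketched below as a standalone alternative, and the class name FisherYates is illustrative. In practice, java.util.Collections.shuffle(Arrays.asList(trainingSet), rand) does the same job in place, because the list returned by Arrays.asList writes through to the array.

```java
import java.util.Arrays;
import java.util.Random;

// Fisher-Yates shuffle: every permutation of the array is equally likely.
final class FisherYates {
    static <T> void shuffle(T[] a, Random rand) {
        // Walk from the end; position i swaps with a random position in [0, i].
        for (int i = a.length - 1; i > 0; i--) {
            int j = rand.nextInt(i + 1);
            T tmp = a[i];
            a[i] = a[j];
            a[j] = tmp;
        }
    }

    public static void main(String[] args) {
        Integer[] a = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
        shuffle(a, new Random(42));
        System.out.println(Arrays.toString(a));
    }
}
```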