java识别手写文字_神经网络入门 第6章 识别手写字体

前言

神经网络是一种很特别的解决问题的方法。本书将用最简单易懂的方式与读者一起从最简单开始,一步一步深入了解神经网络的基础算法。本书将尽量避开让人望而生畏的名词和数学概念,通过构造可以运行的Java程序来实践相关算法。

关注微信号“逻辑编程"来获取本书的更多信息。

这一章节我们将会解决一个真正的问题:手写字体识别。我们将识别像下面图中这样的手写数字。

a6128df41843037fb4ebd9bed195e96e.png

在开始之前,我们先要准备好相应的测试数据。我们不能像前边那样简单的产生手写字体,毕竟我们自己还不知道如何写出一个产生手写字体的算法。训练要达到一定的精度需要较多的训练数据。还好,前人栽树后人乘凉,先驱们已经收集了宝贵的训练材料。MNIST就是一个广泛使用的数据集。不但可以拿来用,我们还可以从网站上看到别人的识别准确率。这样我们就有了很好的参照。MNIST包含一套训练数据和一套测试数据,分别来自不同的人群的手写。

MNIST网站: http://yann.lecun.com/exdb/mnist/

这个数据集是写在特定的二进制文件中的,并非普通图片格式。每个图片数据由28*28个像素组成。每个像素1个字节表示颜色灰度级。MNIST网站上有具体的介绍。

我们写一个类来完成数据集的读取工作,并提供接口返回指定的训练或者测试数据。具体代码不做分析,仅将代码附在下面,供读者使用。代码执行前要先下载数据文件并保留GZIP格式。代码执行后将随机抽取20个生成PNG图片供读者自己查看和验证数据内容。

下面我们写个测试类来识别手写字体。我们使用MNIST库的60000训练数据来反复训练我们的神经网络。每轮训练后使用MNIST库的10000个测试数据来测试识别率。

下面是代码:

package com.luoxq.ann;

import java.util.Arrays;

import java.util.Random;

public class MnistTest {

public static void main(String... args) {

int[] shape = {28 * 28, 10};

NeuralNetwork nn = new NeuralNetwork(shape);

Mnist mnist = new Mnist();

mnist.load();

mnist.shuffle();

System.out.println("Shape: " + Arrays.toString(shape));

System.out.println("Initial correct rate: " + test(nn, mnist));

int epochs = 1000;

double rate = 0.5;

System.out.println("Learning rate: " + rate);

System.out.println("Epoch,Time,Correctness\n----------------------");

long time = System.currentTimeMillis();

Mnist.Data[] data = mnist.getTrainingSlice(0, 60000);

for (int epoch = 1; epoch <= epochs; epoch++) {

for (int sample = 0; sample < data.length; sample++) {

nn.train(data[sample].input, data[sample].output, rate);

}

  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值