Neural Network实战:Java实现Back Propagation算法 + 手写数字识别

本文介绍了作者通过Java实现经典神经网络,并使用Back Propagation算法进行训练的过程。详细讨论了BP算法的原理,以及如何利用它解决非凸优化问题。此外,还分享了手写数字识别的实践案例,将28*28的黑白图片作为输入进行识别。代码已上传至GitHub供参考。
摘要由CSDN通过智能技术生成

工作中时不时的会使用机器学习解决一些分类问题。但是一般都是使用已有的机器学习库,比如Weka,Scikit-learn等简单易用的库。

对于一个工程师, 够理解常见的机器学习模型(比如SVM,Naive bayes, Random forest, Neural network, Decision tree, 等等) + 降维技术(比如PCA, SVD, 等),使用已有的库, 来解决常见的工程问题。达到这个境界是不是就够了呢?

我个人觉得不是很满足。对于一些模型,已经训练模型的算法,虽然我看过不少书籍和文章,对于其中奥秘,自我感觉能够理解。但是作为工程师,不是自己实现的,用起来总是惴惴不安。或者自己能够实现过,再用现成的工业级别的库,也会心安理得一点。

那好,这篇文章,我就来用Java实现经典的Neural Network。训练的算法使用经典的Back Propagation算法。简单说两句BP算法:该算法使用最速梯度下降来求解目标函数的最小值。改目标函数是非凸的,因此使用梯度下降,容易求得次优解。一些解决办法包括:使用物理上冲量概念,当落入次优解的时候,算法本身有一定能力冲出次优解,再次滑向最优解。

本篇涉及到的所有code都放到我的github上了: https://github.com/zhangfaen/ML/tree/master/neural_network

 详细的参考:http://en.wikipedia.org/wiki/Artificial_neural_network 。 另外,引用一张来自Andrew Moore的slide,该slide深刻的描述了BP的本质


我今天自己手工推导了一遍,涉及到的主要知识点是:复合函数求偏微分。


完成了推导,我们开始实现。

实现一个Java类

package faen;

import java.util.Arrays;
import java.util.Random;

// http://en.wikipedia.org/wiki/Artificial_neural_network
// A kind of non linear model of Machine Learning.
public class NN {
    static class Util {
        public static void CHECK(boolean condition, String message) {
            if (!condition) {
                throw new RuntimeException(message);
            }
        }
    }

    private int expandedInputNodes;
    private int hiddenNodes;
    private int outputNodes;

    // Weights matrix between input layer and hidden layer
    private double[][] wi;
    // Weights matrix between hidden layer and output layer.
    private double[][] wo;

    // last change in weights for momentum
    private double[][] wi_momentum;
    // last change in weights for momentum
    private double[][] wo_momentum;

    // Expanded instance, whose size is this.outputSize + 1.
    // The last element will be fixed to 1.0
    private double[] expandedInstance;

    private double[] hiddenActivations;
    private double[] outputActivations;

    // The sigmoid function: s(x) = 1 / (1 + (e^-x))
    // The derivative of s(x): s(x) * (1 - s(x))
    private double s(double x) {
        return 1.0 / (1.0 + Math.pow(Math.E, -x));
    }

    public NN(int featuresOfInstance, int nodesOfHiddenLayer, int nodesOfOutputLayer) {
        Util.CHECK(featuresOfInstance > 0, "");
        Util.CHECK(nodesOfHiddenLayer > 0, "");
        Util.CHECK(nodesOfOutputLayer > 0, "");
        this.expandedInputNodes = featuresOfInstance + 1;
        this.hiddenNodes = nodesOfHiddenLayer;
        this.outputNodes = nodesOfOutputLayer;
        this.wi = new double[this.expandedInputNodes][this.hiddenNodes];
        this.wo = new double[this.hiddenNodes][
  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值