工作中时不时的会使用机器学习解决一些分类问题。但是一般都是使用已有的机器学习库,比如Weka,Scikit-learn等简单易用的库。
对于一个工程师, 够理解常见的机器学习模型(比如SVM,Naive bayes, Random forest, Neural network, Decision tree, 等等) + 降维技术(比如PCA, SVD, 等),使用已有的库, 来解决常见的工程问题。达到这个境界是不是就够了呢?
我个人觉得不是很满足。对于一些模型,已经训练模型的算法,虽然我看过不少书籍和文章,对于其中奥秘,自我感觉能够理解。但是作为工程师,不是自己实现的,用起来总是惴惴不安。或者自己能够实现过,再用现成的工业级别的库,也会心安理得一点。
那好,这篇文章,我就来用Java实现经典的Neural Network。训练的算法使用经典的Back Propagation算法。简单说两句BP算法:该算法使用最速梯度下降来求解目标函数的最小值。改目标函数是非凸的,因此使用梯度下降,容易求得次优解。一些解决办法包括:使用物理上冲量概念,当落入次优解的时候,算法本身有一定能力冲出次优解,再次滑向最优解。
本篇涉及到的所有code都放到我的github上了: https://github.com/zhangfaen/ML/tree/master/neural_network
详细的参考:http://en.wikipedia.org/wiki/Artificial_neural_network 。 另外,引用一张来自Andrew Moore的slide,该slide深刻的描述了BP的本质
我今天自己手工推导了一遍,涉及到的主要知识点是:复合函数求偏微分。
完成了推导,我们开始实现。
实现一个Java类
package faen;
import java.util.Arrays;
import java.util.Random;
// http://en.wikipedia.org/wiki/Artificial_neural_network
// A kind of non linear model of Machine Learning.
public class NN {
static class Util {
public static void CHECK(boolean condition, String message) {
if (!condition) {
throw new RuntimeException(message);
}
}
}
private int expandedInputNodes;
private int hiddenNodes;
private int outputNodes;
// Weights matrix between input layer and hidden layer
private double[][] wi;
// Weights matrix between hidden layer and output layer.
private double[][] wo;
// last change in weights for momentum
private double[][] wi_momentum;
// last change in weights for momentum
private double[][] wo_momentum;
// Expanded instance, whose size is this.outputSize + 1.
// The last element will be fixed to 1.0
private double[] expandedInstance;
private double[] hiddenActivations;
private double[] outputActivations;
// The sigmoid function: s(x) = 1 / (1 + (e^-x))
// The derivative of s(x): s(x) * (1 - s(x))
private double s(double x) {
return 1.0 / (1.0 + Math.pow(Math.E, -x));
}
public NN(int featuresOfInstance, int nodesOfHiddenLayer, int nodesOfOutputLayer) {
Util.CHECK(featuresOfInstance > 0, "");
Util.CHECK(nodesOfHiddenLayer > 0, "");
Util.CHECK(nodesOfOutputLayer > 0, "");
this.expandedInputNodes = featuresOfInstance + 1;
this.hiddenNodes = nodesOfHiddenLayer;
this.outputNodes = nodesOfOutputLayer;
this.wi = new double[this.expandedInputNodes][this.hiddenNodes];
this.wo = new double[this.hiddenNodes][