转载请声明出处http://blog.csdn.net/zhongkejingwang/article/details/44514073
上一篇文章介绍了KNN分类器,当时说了其分类效果不是很出色但是比较稳定,本文后面将利用BP网络同样对Iris数据进行分类。
什么是BP网络
BP神经网络,BP即Back Propagation的缩写,也就是反向传播的意思,顾名思义,将什么反向传播?文中将会解答。不仅如此,关于隐层的含义文中也会给出个人的理解。最后会用Java实现的BP分类器作为其应用以加深印象。
很多初学者刚接触神经网络的时候都会到网上找相关的介绍,看了很多数学原理之后还是云里雾里,然后会琢磨到底这个有什么用?怎么用?于是又到网上找别人写的代码,下下来之后看一眼发现代码写的很糟糕,根本就理不清,怎么看也看不懂,于是就放弃了。作为过来人,本人之前在网上也看过很多关于BP网络的介绍,也下载了别人实现的代码下来研究,原理都一样,但是至今为止没有看到过能令人满意的代码实现。于是就有了这篇文章,不仅有原理也有代码,对节点的高度抽象会让代码更有可读性。
CSDN博客编辑器终于可以编写数学公式了!第一次使用Markdown编辑器,感觉爽歪歪,latex数学公式虽然写起来麻烦,不过很灵活,排版也漂亮~在这里贴一个Markdown输入数学公式的教程http://ttang.name/2014/05/04/markdown-and-mathjax/很全的说!
BP网络的数学原理
下面将介绍BP网络的数学原理,相比起SVD的算法推导,这个简直就是小菜一碟,不就是梯度吗求个导就完事了。首先来看看BP网络长什么样,这就是它的样子:
为了简单起见,这里只介绍只有一个隐层的BP网络,多个隐层的也是一样的原理。这个网络的工作原理应该很清楚了,首先,一组输入
x1、x2、…、xm
来到输入层,然后通过与隐层的连接权重产生一组数据
s1、s2、…、sn
作为隐层的输入,然后通过隐层节点的
θ(⋅)
激活函数后变为
θ(sj)
其中
sj
表示隐层的第
j
个节点产生的输出,这些输出将通过隐层与输出层的连接权重产生输出层的输入,这里输出层的处理过程和隐层是一样的,最后会在输出层产生输出
前面提到激活函数
θ(⋅)
,一般使用S形函数(即sigmoid函数),比如可以使用log-sigmoid:
θ(s)=11+e−s
或者tan-sigmoid:
θ(s)=es−e−ses+e−s
前面说了,既然在输出层产生输出了,那总得看下输出结果对不对吧或者距离预期的结果有多大出入吧?现在就来分析一下什么东西在影响输出。显然,输入的数据是已知的,变量只有那些个连接权重了,那这些连接权重如何影响输出呢?现在假设输入层第i个节点到隐层第j个节点的连接权重发生了一个很小的变化
Δwij
,那么这个
Δwij
将会对
sj
产生影响,导致
sj
也出现一个变化
Δsj
,然后产生
Δθ(sj)
,然后传到各个输出层,最后在所有输出层都产生一个误差
Δe
。所以说,权重的调整将会使得输出结果产生变化,那么如何使这些输出结果往正确方向变化呢?这就是接下来的任务:如何调整权重。对于给定的训练样本,其正确的结果已经知道,那么由输入经过网络的输出和正确的结果比较将会有一个误差,如果能把这个误差将到最小,那么就是输出结果靠近了正确结果,就可以说网络可以对样本进行正确分类了。怎样使得误差最小呢?首先,把误差表达式写出来,为了使函数连续可导,这里最小化均方根差,定义损失函数如下:
用什么方法最小化 L ?跟SVD算法一样,用随机梯度下降。也就是对每个训练样本都使权重往其负梯度方向变化。现在的任务就是求
用
由于
所以
接下来只需求出 ∂L∂s1j 即可。
由于 s1j 对所有输出层都有影响,所以
由于
代入前面的式子可得
现在记
输出层 δ 为
到这一步,可以看到是什么反向传播了吧?没错,就是误差 e !
反向传播过程是这样的:输出层每个节点都会得到一个误差
现在再来看第一层权重的梯度:
第二层权重梯度:
可以看到一个规律: 每个权重的梯度都等于与其相连的前一层节点的输出(即 xi 和 θ(s1i) )乘以与其相连的后一层的反向传播的输出(即 δ1j 和 δ2j )。如果看不明白原理的话记住这句话即可!
这样反向传播得到所有的 δ 以后,就可以更新权重了。更直观的BP神经网络的工作过程总结如下:
上图中每一个节点的输出都和权重矩阵中同一列(行)的元素相乘,然后同一行(列)累加作为下一层对应节点的输入。
为了代码实现的可读性,对节点进行抽象如下:
这样的话,很多步骤都在节点内部进行了。
当 θ(s)=11+e−s 时,
当 θ(s)=es−e−ses+e−s 时,
BP网络原理部分就到这,接下来要根据上图中的神经元模型用代码实现BP网络,然后对Iris数据集进行分类。完整的代码见github: https://github.com/jingchenUSTC/ANN
BP网络算法实现
首先,单个神经元封装代码如下:
//NetworkNode.java
package com.jingchen.ann;
public class NetworkNode
{
public static final int TYPE_INPUT = 0;
public static final int TYPE_HIDDEN = 1;
public static final int TYPE_OUTPUT = 2;
private int type;
public void setType(int type)
{
this.type = type;
}
// 节点前向输入输出值
private float mForwardInputValue;
private float mForwardOutputValue;
// 节点反向输入输出值
private float mBackwardInputValue;
private float mBackwardOutputValue;
public NetworkNode()
{
}
public NetworkNode(int type)
{
this.type = type;
}
/**
* sigmoid函数,这里用tan-sigmoid,经测试其效果比log-sigmoid好!
*
* @param in
* @return
*/
private float forwardSigmoid(float in)
{
switch (type)
{
case TYPE_INPUT:
return in;
case TYPE_HIDDEN:
case TYPE_OUTPUT:
return tanhS(in);
}
return 0;
}
/**
* log-sigmoid函数
*
* @param in
* @return
*/
private float logS(float in)
{
return (float) (1 / (1 + Math.exp(-in)));
}
/**
* log-sigmoid函数的导数
*
* @param in
* @return
*/
private float logSDerivative(float in)
{
return mForwardOutputValue * (1 - mForwardOutputValue) * in;
}
/**
* tan-sigmoid函数
*
* @param in
* @return
*/
private float tanhS(float in)
{
return (float) ((Math.exp(in) - Math.exp(-in)) / (Math.exp(in) + Math
.exp(-in)));
}
/**
* tan-sigmoid函数的导数
*
* @param in
* @return
*/
private float tanhSDerivative(float in)
{
return (float) ((1 - Math.pow(mForwardOutputValue, 2)) * in);
}
/**
* 误差反向传播时,激活函数的导数
*
* @param in
* @return
*/
private float backwardPropagate(float in)
{
switch (type)
{
case TYPE_INPUT:
return in;
case TYPE_HIDDEN:
case TYPE_OUTPUT:
return tanhSDerivative(in);
}
return 0;
}
public float getForwardInputValue()
{
return mForwardInputValue;
}
public void setForwardInputValue(float mInputValue)
{
this.mForwardInputValue = mInputValue;
setForwardOutputValue(mInputValue);
}
public float getForwardOutputValue()
{
return mForwardOutputValue;
}
private void setForwardOutputValue(float mInputValue)
{
this.mForwardOutputValue = forwardSigmoid(mInputValue);
}
public float getBackwardInputValue()
{
return mBackwardInputValue;
}
public void setBackwardInputValue(float mBackwardInputValue)
{
this.mBackwardInputValue = mBackwardInputValue;
setBackwardOutputValue(mBackwardInputValue);
}
public float getBackwardOutputValue()
{
return mBackwardOutputValue;
}
private void setBackwardOutputValue(float input)
{
this.mBackwardOutputValue = backwardPropagate(input);
}
}
然后就是整个神经网络类:
//AnnClassifier.java
package com.jingchen.ann;
import java.util.ArrayList;
import java.util.List;
/**
* 人工神经网络分类器
*
* @author chenjing
*
*/
public class AnnClassifier
{
private int mInputCount;
private int mHiddenCount;
private int mOutputCount;
private List<NetworkNode> mInputNodes;
private List<NetworkNode> mHiddenNodes;
private List<NetworkNode> mOutputNodes;
private float[][] mInputHiddenWeight;
private float[][] mHiddenOutputWeight;
private List<DataNode> trainNodes;
public void setTrainNodes(List<DataNode> trainNodes)
{
this.trainNodes = trainNodes;
}
public AnnClassifier(int inputCount, int hiddenCount, int outputCount)
{
trainNodes = new ArrayList<DataNode>();
mInputCount = inputCount;
mHiddenCount = hiddenCount;
mOutputCount = outputCount;
mInputNodes = new ArrayList<NetworkNode>();
mHiddenNodes = new ArrayList<NetworkNode>();
mOutputNodes = new ArrayList<NetworkNode>();
mInputHiddenWeight = new float[inputCount][hiddenCount];
mHiddenOutputWeight = new float[mHiddenCount][mOutputCount];
}
/**
* 更新权重,每个权重的梯度都等于与其相连的前一层节点的输出乘以与其相连的后一层的反向传播的输出
*/
private void updateWeights(float eta)
{
//更新输入层到隐层的权重矩阵
for (int i = 0; i < mInputCount; i++)
for (int j = 0; j < mHiddenCount; j++)
mInputHiddenWeight[i][j] -= eta
* mInputNodes.get(i).getForwardOutputValue()
* mHiddenNodes.get(j).getBackwardOutputValue();
//更新隐层到输出层的权重矩阵
for (int i = 0; i < mHiddenCount; i++)
for (int j = 0; j < mOutputCount; j++)
mHiddenOutputWeight[i][j] -= eta
* mHiddenNodes.get(i).getForwardOutputValue()
* mOutputNodes.get(j).getBackwardOutputValue();
}
/**
* 前向传播
*/
private void forward(List<Float> list)
{
// 输入层
for (int k = 0; k < list.size(); k++)
mInputNodes.get(k).setForwardInputValue(list.get(k));
// 隐层
for (int j = 0; j < mHiddenCount; j++)
{
float temp = 0;
for (int k = 0; k < mInputCount; k++)
temp += mInputHiddenWeight[k][j]
* mInputNodes.get(k).getForwardOutputValue();
mHiddenNodes.get(j).setForwardInputValue(temp);
}
// 输出层
for (int j = 0; j < mOutputCount; j++)
{
float temp = 0;
for (int k = 0; k < mHiddenCount; k++)
temp += mHiddenOutputWeight[k][j]
* mHiddenNodes.get(k).getForwardOutputValue();
mOutputNodes.get(j).setForwardInputValue(temp);
}
}
/**
* 反向传播
*/
private void backward(int type)
{
// 输出层
for (int j = 0; j < mOutputCount; j++)
{
//输出层计算误差把误差反向传播,这里-1代表不属于,1代表属于
float result = -1;
if (j == type)
result = 1;
mOutputNodes.get(j).setBackwardInputValue(
mOutputNodes.get(j).getForwardOutputValue() - result);
}
// 隐层
for (int j = 0; j < mHiddenCount; j++)
{
float temp = 0;
for (int k = 0; k < mOutputCount; k++)
temp += mHiddenOutputWeight[j][k]
* mOutputNodes.get(k).getBackwardOutputValue();
}
}
public void train(float eta, int n)
{
reset();
for (int i = 0; i < n; i++)
{
for (int j = 0; j < trainNodes.size(); j++)
{
forward(trainNodes.get(j).getAttribList());
backward(trainNodes.get(j).getType());
updateWeights(eta);
}
}
}
/**
* 初始化
*/
private void reset()
{
mInputNodes.clear();
mHiddenNodes.clear();
mOutputNodes.clear();
for (int i = 0; i < mInputCount; i++)
mInputNodes.add(new NetworkNode(NetworkNode.TYPE_INPUT));
for (int i = 0; i < mHiddenCount; i++)
mHiddenNodes.add(new NetworkNode(NetworkNode.TYPE_HIDDEN));
for (int i = 0; i < mOutputCount; i++)
mOutputNodes.add(new NetworkNode(NetworkNode.TYPE_OUTPUT));
for (int i = 0; i < mInputCount; i++)
for (int j = 0; j < mHiddenCount; j++)
mInputHiddenWeight[i][j] = (float) (Math.random() * 0.1);
for (int i = 0; i < mHiddenCount; i++)
for (int j = 0; j < mOutputCount; j++)
mHiddenOutputWeight[i][j] = (float) (Math.random() * 0.1);
}
public int test(DataNode dn)
{
forward(dn.getAttribList());
float result = 2;
int type = 0;
//取最接近1的
for (int i = 0; i < mOutputCount; i++)
if ((1 - mOutputNodes.get(i).getForwardOutputValue()) < result)
{
result = 1 - mOutputNodes.get(i).getForwardOutputValue();
type = i;
}
return type;
}
}
Iris数据有三种类别,所以输出层会有三个节点,每个节点代表一种类别,节点输出1(具体根据所用激活函数的上界)则表示属于该类,输出-1(具体根据所用激活函数的下界)则表示不属于该类。
完整的代码已共享到github,地址:https://github.com/jingchenUSTC/ANN。用BP网络对Iris数据进行分类的准确率接近100%!