BP神经网络
时间:2022/8/2
1.神经元模型
在生物神经网络中,最主要的结构便是神经元,如上图所示,便是生物神经元的结构模型,树突感知其他神经元传递的信息,通过轴突向后传播。神经元与神经元之间传播时通过一个突触的结构,通过神经递质改变下一个神经元的电位,当电位超过一定阈值,则信息将通过电位在下一个神经元上进行传递。故我们可将神经元的这种结构特性进行提取,构建神经元模型。
2.M-P神经元模型
M-P神经元模型是1943年由[McCulloch and Pitts,1943]将生物神经元抽象出来的神经元模型,一直沿用至今。在这个模型中,神经元接收n个其他神经元传递的输入,模拟树突。这些输入信号通过带权重的连接进行传递,神经元接受到的总输入值将与神经元的阈值进行比较,然后通过神经元的“激活函数”(ativation function)的处理以产生神经元的输出。虽然说看似是对生物神经元的模拟,但我感觉并不是真的模拟的生物神经元,感觉还是差别挺大的,归根结底,还是一个数学模型。
激活函数常采用阶跃函数和sigmoid函数进行处理。下图便是sigmoid函数的图像,该函数可以将较大的数值变化,转换成较小的区间**[0,1]**里,有时也称之为挤压函数。
3.感知机与多层网络
感知机(Perceptron由两层神经元构成,如下图所示,输入层接收外界输入信号,输出层则是M-P神经元。其输出函数为:
y
=
f
(
∑
i
w
i
x
i
−
θ
)
(1)
y=f(\sum_i{w_ix_i}-\theta) \tag{1}
y=f(i∑wixi−θ)(1)
感知机能轻松实现逻辑与、或、非的运算。
对于感知机的训练过程,一般地,给定训练数据集,权重 w i ( i = 1 , 2 , . . , n ) w_i(i=1,2,..,n) wi(i=1,2,..,n)以及阈值 θ \theta θ可通过学习得到。阈值 θ \theta θ可以看作是一个固定输入-1.0的”哑结点“所对应的连接权重为 w n + 1 w_{n+1} wn+1,这样,将权重学习与阈值学习统一为权重的学习。而感知机的学习规则非常简单,对于训练样本 ( x , y ) (x,y) (x,y),若当前感知机的输出为 y ^ \hat y y^,则感知机的权值调整:
w i ← w i + Δ w i , Δ w i = η ( y − y ^ ) x i (2) w_i\leftarrow w_i+\Delta w_i,\\ \Delta w_i=\eta(y-\hat y)x_i \tag{2} wi←wi+Δwi,Δwi=η(y−y^)xi(2)
欲解决非线性问题,则需要多层功能神经元。如下图的三层神经元网络,在输入层和输出层之间再加入一层神经元,这一层称之为**“隐含层(hidden layer)”** 。通过如图的多层网络,即可解决异或的问题。
更一般的神经网络则是如下图所示的**“多层前馈神经网络”。即每层神经元与下层神经元完全互连,即“全连接”**。同一层神经元不相连,而且不能跨层连接。
输入层神经元只是接受数据输入,不进行任何处理,而隐含层和输出层则包含功能神经元。神经网络的学习过程,就是根据训练数据来调整神经元之间的“连接权”和功能神经元的阈值的过程。
3.误差反向传播算法(BP)
对于多层学习,其学习规则比感知机更为复杂。 而误差反向传播(error BackPropagation)则是最为成功的神经网络算法。如图所示,便是BP算法的网络结构。
3.1 正向传播
对于训练集 ( X k , y k ) (X_k,y_k) (Xk,yk),假定神经网络的输出为 y ^ = ( y ^ 1 k , y ^ 2 k , . . . , y ^ l k ) \hat y =(\hat y _1^k,\hat y _2^k,...,\hat y _l^k) y^=(y^1k,y^2k,...,y^lk).则,设 β j = ∑ w h j x i \beta_j=\sum w_{hj}x_i βj=∑whjxi为第j个输出神经元的输入, θ j \theta_j θj为第j个输出神经元的阈值, w h j w_{hj} whj为隐含层第h个神经元到第j个输出神经元的连接权,激活函数为sigmoid,其输出函数为:
y ^ j k = f ( β j − θ j ) (3) \hat y_j^k=f(\beta_j-\theta_j) \tag{3} y^jk=f(βj−θj)(3)
则网络在该数据集 ( X k , y k ) (X_k,y_k) (Xk,yk)上的方差为:
E k = 1 2 ∑ j = 1 l ( y ^ j k − y j k ) 2 (4) E_k={1\over2}\sum_{j=1}^l(\hat y_j^k-y_j^k)^2\tag{4} Ek=21j=1∑l(y^jk−yjk)2(4)
3.2 反向传播
而BP算法的目标则是最小化网络的均方误差。其采用的优化策略为梯度下降法。根据梯度下降计算出的梯度项进行参数调整。
Δ w h j = − η ∂ E k ∂ w h j (5) \Delta w_{hj}=-\eta\frac{ \partial E_k }{ \partial w_{hj}}\tag{5} Δwhj=−η∂whj∂Ek(5)
则按照梯度下降的计算公式,对于隐含层到输出层的权值调整为:
Δ w h j = η g j b h (6) \Delta w_{hj}=\eta g_jb_h \tag{6} Δwhj=ηgjbh(6)
其中 b h b_h bh为第h个隐含层神经元的输出。 g j g_j gj为梯度项:
g j = − ∂ E k ∂ y ^ j k ∗ y ^ j k ∂ β j = y ^ j k ( 1 − y ^ j k ) ( y j k − y ^ j k ) (7) \begin{split} g_j & =-\frac{ \partial E_k }{ \partial \hat y_j^k}*\frac{ \hat y_j^k}{ \partial \beta _j} \\ &=\hat y_j^k(1-\hat y_j^k)(y_j^k-\hat y_j^k) \end{split}\tag{7} gj=−∂y^jk∂Ek∗∂βjy^jk=y^jk(1−y^jk)(yjk−y^jk)(7)
同理可得:
Δ θ j = − η g j (8) \Delta \theta_j=-\eta g_j \tag{8} Δθj=−ηgj(8)
Δ w i h = η e h x 1 (9) \Delta w_{ih}=\eta e_h x_1 \tag{9} Δwih=ηehx1(9)
Δ γ h = − η e h (10) \Delta \gamma_h=-\eta e_h \tag{10} Δγh=−ηeh(10)
其中 e h e_h eh为隐含层权值调整梯度项,设第h个隐含层的输入为 α h \alpha_h αh,则:
e h = − ∂ E k ∂ b h ∗ b h ∂ α h = b h ( 1 − b h ) ∑ j = 1 l w h i g j (11) \begin{split} e_h & =-\frac{ \partial E_k }{ \partial b_h}*\frac{ b_h}{ \partial \alpha_h} \\ &=b_h(1-b_h)\sum_{j=1}^lw_{hi}g_j \end{split}\tag{11} eh=−∂bh∂Ek∗∂αhbh=bh(1−bh)j=1∑lwhigj(11)
对任意参数v其更新公式为:
v ← v + Δ v (12) v\leftarrow v+\Delta v \tag{12} v←v+Δv(12)
4.算法流程
输入:训练集 D = { ( X k , y k ) } m D=\{(X_k,y_k)\}^m D={(Xk,yk)}m,学习率 η \eta η
过程:
1.在(0,1)的范围内随机初始化网络中的所有参数。
2.repeat
-
for all( ( X k , y k ) ∈ D (X_k,y_k)\in D (Xk,yk)∈D) do
-
正向传播,根据当前参数和输入数据计算当前样本的输出 y ^ j k \hat y_j^k y^jk
-
根据公式(7)计算输出层神经元的梯度项 g j g_j gj
-
根据公式(11)计算隐含层神经元梯度项 e h e_h eh
-
根据公式(6)(8-10)(12)更新连接权与阈值
end for
util 达到停止条件
-
输出:训练好的神经网络
5.算法实现
1.GeneralAnn.java
抽象类,定义ANN的基础结构
/**
* GeneralAnn.java
*
* @Author zjy
* @Date 2022/7/28
* @Description: 抽象类,定义ANN的基础结构
* @Version V1.0
*/
package swpu.zjy.ML.ANN.myAnn;
import weka.core.Instances;
import java.io.FileReader;
import java.util.Arrays;
import java.util.Random;
public abstract class GeneralAnn {
//数据集
Instances dataset;
//神经网络层数
int numLayers;
//每层神经网络结点数量
int[] numLayerNodes;
//动量,用于加速梯度下降
public double mobp;
//梯度下降学习率
public double learningRate;
//随机数生成器
Random random = new Random();
/**
* 构造器,初始化网络参数
*
* @param datasetFileName 数据集地址
* @param paraNumLayerNodes 神经网络每层结点数量
* @param paraMobp 动量
* @param paraLR 学习率
*/
public GeneralAnn(String datasetFileName, int[] paraNumLayerNodes, double paraMobp, double paraLearnRate) {
FileReader fileReader = null;
try {
fileReader = new FileReader(datasetFileName);
dataset = new Instances(fileReader);
dataset.setClassIndex(dataset.numAttributes() - 1);
fileReader.close();
} catch (Exception e) {
System.out.println(e);
}
numLayerNodes = paraNumLayerNodes;
numLayers = numLayerNodes.length;
numLayerNodes[0] = dataset.numAttributes() - 1;
numLayerNodes[numLayers-1]=dataset.numClasses();
learningRate = paraLearnRate;
mobp = paraMobp;
}
/**
* 抽象方法,向前传播,输出预测结果
*
* @param paraInput 输入数据
* @return 预测结果
*/
public abstract double[] forward(double[] paraInput);
/**
* 抽象方法,向后传播,调整网络参数
*
* @param paraTarget 预计目标结果
*/
public abstract void backPropagation(double[] paraTarget);
/**
* 使用数据集进行训练
*/
public void train() {
double[] tempInput = new double[dataset.numAttributes() - 1];
double[] tempTarget = new double[dataset.numClasses()];
for (int i = 0; i < dataset.numInstances(); i++) {
/**
* 填充输入数据
*/
for (int j = 0; j < tempInput.length; j++) {
tempInput[j] = dataset.instance(i).value(j);
}
/**
* 填充预计目标
*/
Arrays.fill(tempTarget, 0);
tempTarget[(int) dataset.instance(i).classValue()] = 1;
forward(tempInput);
backPropagation(tempTarget);
}
}
/**
* 获取最大值下标
*
* @param paraArray 数组
* @return 最大值下标
*/
public static int argmax(double[] paraArray) {
int resultIndex = -1;
double tempMax = -1e10;
for (int i = 0; i < paraArray.length; i++) {
if (tempMax < paraArray[i]) {
tempMax = paraArray[i];
resultIndex = i;
}
}
return resultIndex;
}
/**
* 进行测试
*
* @return AUC
*/
public double test() {
double[] tempInput = new double[dataset.numAttributes() - 1];
double tempNumCorrect = 0;
double[] tempPrediction;
int[] predict=new int[dataset.numInstances()];
Arrays.fill(predict,0);
int tempPredictedClass = -1;
for (int i = 0; i < dataset.numInstances(); i++) {
for (int j = 0; j < tempInput.length; j++) {
tempInput[j] = dataset.instance(i).value(j);
}
tempPrediction = forward(tempInput);
tempPredictedClass = argmax(tempPrediction);
predict[i]=tempPredictedClass;
if (tempPredictedClass == (int) dataset.instance(i).classValue()) {
tempNumCorrect++;
}
}
System.out.println("Correct:" + tempNumCorrect + "out of" + dataset.numInstances());
return tempNumCorrect / dataset.numInstances();
}
}
2.Activator.java
激活函数封装
package swpu.zjy.ML.ANN.myAnn;
/**
* Activator.java
*
* @Author Fan Min minfanphd@163.com.
* @Date 2022/7/29
* @Description: 激活函数
* @Version V1.0
*/
public class Activator {
public final char ARC_TAN='a';
public final char ELU='e';
public final char GELU= 'g';
public final char HARD_LOGISTIC='h';
public final char IDENTITY = 'i';
public final char LEAKY_RELU = 'l';
public final char RELU = 'r';
public final char SOFT_SIGN = 'o';
public final char SIGMOID = 's';
public final char TANH = 't';
public final char SOFT_PLUS = 'u';
public final char SWISH = 'w';
private char activator;
double alpha;
double beta;
double gamma;
/**
* 构造器。设置激活函数类型
* @param activator 激活函数类型
*/
public Activator(char activator) {
this.activator = activator;
}
public char getActivator() {
return activator;
}
public void setActivator(char activator) {
this.activator = activator;
}
public void setAlpha(double alpha) {
this.alpha = alpha;
}
public void setBeta(double beta) {
this.beta = beta;
}
public void setGamma(double gamma) {
this.gamma = gamma;
}
/**
* 激活函数
* @param paraValue 欲激活数据
* @return 激活后的数据
*/
public double activate(double paraValue) {
double resultValue = 0;
switch (activator) {
case ARC_TAN:
resultValue = Math.atan(paraValue);
break;
case ELU:
if (paraValue >= 0) {
resultValue = paraValue;
} else {
resultValue = alpha * (Math.exp(paraValue) - 1);
} // Of if
break;
// case GELU:
// resultValue = ?;
// break;
// case HARD_LOGISTIC:
// resultValue = ?;
// break;
case IDENTITY:
resultValue = paraValue;
break;
case LEAKY_RELU:
if (paraValue >= 0) {
resultValue = paraValue;
} else {
resultValue = alpha * paraValue;
} // Of if
break;
case SOFT_SIGN:
if (paraValue >= 0) {
resultValue = paraValue / (1 + paraValue);
} else {
resultValue = paraValue / (1 - paraValue);
} // Of if
break;
case SOFT_PLUS:
resultValue = Math.log(1 + Math.exp(paraValue));
break;
case RELU:
if (paraValue >= 0) {
resultValue = paraValue;
} else {
resultValue = 0;
} // Of if
break;
case SIGMOID:
resultValue = 1 / (1 + Math.exp(-paraValue));
break;
case TANH:
resultValue = 2 / (1 + Math.exp(-2 * paraValue)) - 1;
break;
// case SWISH:
// resultValue = ?;
// break;
default:
System.out.println("Unsupported activator: " + activator);
System.exit(0);
}// Of switch
return resultValue;
}// Of activate
/**
* 激活函数求导,用于反向传播
* @param paraValue 源数据
* @param paraActivatedValue f(x)
* @return
*/
public double derive(double paraValue,double paraActivatedValue) {
double resultValue = 0;
switch (activator) {
case ARC_TAN:
resultValue = 1 / (paraValue * paraValue + 1);
break;
case ELU:
if (paraValue >= 0) {
resultValue = 1;
} else {
resultValue = alpha * (Math.exp(paraValue) - 1) + alpha;
} // Of if
break;
// case GELU:
// resultValue = ?;
// break;
// case HARD_LOGISTIC:
// resultValue = ?;
// break;
case IDENTITY:
resultValue = 1;
break;
case LEAKY_RELU:
if (paraValue >= 0) {
resultValue = 1;
} else {
resultValue = alpha;
} // Of if
break;
case SOFT_SIGN:
if (paraValue >= 0) {
resultValue = 1 / (1 + paraValue) / (1 + paraValue);
} else {
resultValue = 1 / (1 - paraValue) / (1 - paraValue);
} // Of if
break;
case SOFT_PLUS:
resultValue = 1 / (1 + Math.exp(-paraValue));
break;
case RELU: // Updated
if (paraValue >= 0) {
resultValue = 1;
} else {
resultValue = 0;
} // Of if
break;
case SIGMOID: // Updated
resultValue = paraActivatedValue * (1 - paraActivatedValue);
break;
case TANH: // Updated
resultValue = 1 - paraActivatedValue * paraActivatedValue;
break;
// case SWISH:
// resultValue = ?;
// break;
default:
System.out.println("Unsupported activator: " + activator);
System.exit(0);
}// Of switch
return resultValue;
}
public String toString() {
String resultString = "Activator with function '" + activator + "'";
resultString += "\r\n alpha = " + alpha + ", beta = " + beta + ", gamma = " + gamma;
return resultString;
}
public static void main(String[] args) {
Activator tempActivator = new Activator('s');
double tempValue = 0.6;
double tempNewValue;
tempNewValue = tempActivator.activate(tempValue);
System.out.println("After activation: " + tempNewValue);
tempNewValue = tempActivator.derive(tempValue, tempNewValue);
System.out.println("After derive: " + tempNewValue);
}
}
3.AnnLayer.java
Ann 层的实现
package swpu.zjy.ML.ANN.myAnn;
import java.util.Arrays;
import java.util.Random;
/**
* AnnLayer.java
*
* @Author zjy
* @Date 2022/7/30
* @Description: ANN Layer定义
* @Version V1.0
*/
public class AnnLayer {
//输入数据数量
int numInput;
//输出数据数量
int numOutput;
//学习率
double learningRate;
//动量
double mobp;
//权值
double[][] weights;
//权值改变量
double[][] deltaWeights;
//输出误差
double[] errors;
//输入数据
double[] input;
//输出数据
double[] output;
//激活后的输出数据
double[] activatedOutput;
//激活函数
Activator activator;
//随机数生成器
Random random=new Random();
/**
* 构造器,初始化层相关参数
* @param paraNumInput 输出数据数量
* @param paraNumOutput 输出数据数量
* @param paraActivator 激活函数类型
* @param paraLearningRate 学习率
* @param paraMobp 动量
*/
public AnnLayer(int paraNumInput, int paraNumOutput, char paraActivator,
double paraLearningRate, double paraMobp){
numInput=paraNumInput;
numOutput=paraNumOutput;
learningRate=paraLearningRate;
mobp=paraMobp;
weights=new double[numInput+1][numOutput];
deltaWeights=new double[numInput+1][numOutput];
//初始化权值
for (int i = 0; i < numInput+1; i++) {
for (int j = 0; j < numOutput; j++) {
weights[i][j]=random.nextDouble();
}
}
errors=new double[numInput];
input=new double[numInput];
output=new double[numOutput];
activatedOutput=new double[numOutput];
activator=new Activator(paraActivator);
}
/**
* 设置激活函数参数
* @param paraAlpha
* @param paraBeta
* @param paraGamma
*/
public void setParameters(double paraAlpha, double paraBeta, double paraGamma) {
activator.setAlpha(paraAlpha);
activator.setBeta(paraBeta);
activator.setGamma(paraGamma);
}
/**
* 向前传播
* @param paraInput 输入数据
* @return 输出结果
*/
public double[] forword(double[] paraInput){
for (int i = 0; i < numInput; i++) {
input[i]=paraInput[i];
}
for (int i = 0; i < numOutput; i++) {
//读入偏移量
output[i]=weights[numInput][i];
for (int j = 0; j < numInput; j++) {
output[i]+=input[j]*weights[j][i];
}
activatedOutput[i]=activator.activate(output[i]);
}
return activatedOutput;
}
/**
* 单层反向传播
* @param paraErrors 后一层的传播误差
* @return 向前一层传播的误差
*/
public double[] backPropagation(double[] paraErrors){
//计算该层梯度项
for (int i = 0; i < paraErrors.length; i++) {
paraErrors[i]=activator.derive(output[i],activatedOutput[i])*paraErrors[i];
}
//权值调整
for (int i = 0; i < numInput; i++) {
errors[i]=0;
//更新连接权值
for (int j = 0; j < numOutput; j++) {
errors[i] += paraErrors[j] * weights[i][j];
deltaWeights[i][j]=mobp*deltaWeights[i][j]+learningRate*paraErrors[j]*input[i];
weights[i][j]+=deltaWeights[i][j];
}
}
//更新阈值
for (int j = 0; j < numOutput; j++) {
deltaWeights[numInput][j]=mobp*deltaWeights[numInput][j]+learningRate*paraErrors[j];
weights[numInput][j]+=deltaWeights[numInput][j];
}
return errors;
}
/**
* 获取后一层误差
* @param paraTarget 预计目标
* @return 本次传播误差
*/
public double[] getLastLayerErrors(double[] paraTarget){
double[] resultErrors=new double[numOutput];
for (int i = 0; i < numOutput; i++) {
resultErrors[i]=(paraTarget[i]-activatedOutput[i]);
}
return resultErrors;
}
public String toString() {
String resultString = "";
resultString += "Activator: " + activator;
resultString += "\r\n weights = " + Arrays.deepToString(weights);
return resultString;
}
public static void unitTest() {
swpu.zjy.ML.ANN.teacher.AnnLayer tempLayer = new swpu.zjy.ML.ANN.teacher.AnnLayer(2, 3, 's', 0.01, 0.1);
double[] tempInput = { 1, 4 };
System.out.println(tempLayer);
double[] tempOutput = tempLayer.forward(tempInput);
System.out.println("Forward, the output is: " + Arrays.toString(tempOutput));
double[] tempError = tempLayer.backPropagation(tempOutput);
System.out.println("Back propagation, the error is: " + Arrays.toString(tempError));
}
public static void main(String[] args) {
unitTest();
}
}
4.FullAnn.java
BP神经网络实现
package swpu.zjy.ML.ANN.myAnn;
/**
* FullAnn.java
*
* @Author zjy
* @Date 2022/7/30
* @Description: BP神经网络实现
* @Version V1.0
*/
public class FullAnn extends GeneralAnn{
//神经网络层数
AnnLayer[] layers;
/**
* 构造器,初始化网络参数,构造神经网络结构
*
* @param datasetFileName 数据集地址
* @param paraNumLayerNodes 神经网络每层结点数量
* @param paraMobp 动量
* @param paraLearnRate 学习率
*/
public FullAnn(String datasetFileName, int[] paraNumLayerNodes, double paraMobp, double paraLearnRate,String paraActivator) {
super(datasetFileName, paraNumLayerNodes, paraMobp, paraLearnRate);
//生成层
layers=new AnnLayer[numLayers-1];
for (int i = 0; i < layers.length; i++) {
layers[i]=new AnnLayer(numLayerNodes[i],numLayerNodes[i+1],paraActivator.charAt(i),paraLearnRate,paraMobp);
}
}
/**
* 前向传播
* @param paraInput 输入数据
* @return 预测结果
*/
@Override
public double[] forward(double[] paraInput) {
double[] resultArray=paraInput;
for (int i = 0; i < numLayers-1; i++) {
resultArray=layers[i].forword(resultArray);
}
return resultArray;
}
/**
*
* @param paraTarget 预计目标结果
*/
@Override
public void backPropagation(double[] paraTarget) {
double[] tempErrors=layers[numLayers-2].getLastLayerErrors(paraTarget);
for (int i = numLayers - 2; i >= 0; i--){
tempErrors=layers[i].backPropagation(tempErrors);
}
return;
}
public String toString() {
return "I am a full ANN with " + numLayers + " layers";
}
public static void main(String[] args) {
int[] tempLayerNodes = { 4, 8, 8, 3 };
FullAnn tempNetwork = new FullAnn("src/main/java/swpu/zjy/ML/DataSet/iris.arff", tempLayerNodes, 0.6,0.01,
"sss");
for (int round = 0; round < 5000; round++) {
tempNetwork.train();
} // Of for n
double tempAccuray = tempNetwork.test();
System.out.println("The accuracy is: " + tempAccuray);
System.out.println("FullAnn ends.");
}
}
String() {
return "I am a full ANN with " + numLayers + " layers";
}
public static void main(String[] args) {
int[] tempLayerNodes = { 4, 8, 8, 3 };
FullAnn tempNetwork = new FullAnn("src/main/java/swpu/zjy/ML/DataSet/iris.arff", tempLayerNodes, 0.6,0.01,
"sss");
for (int round = 0; round < 5000; round++) {
tempNetwork.train();
} // Of for n
double tempAccuray = tempNetwork.test();
System.out.println("The accuracy is: " + tempAccuray);
System.out.println("FullAnn ends.");
}
}
5.运行测试