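Learning source: 日撸 Java 三百行 (Days 71-80, BP Neural Network), 闵帆's blog, CSDN
General BP Neural Network (3. Comprehensive Test)
This part implements a complete ANN. The network structure is as shown in the figure below: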
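1) The complete ANN has four layers of nodes: an input layer with 4 nodes, two hidden layers with 8 nodes each, and an output layer with 3 nodes. Adjacent layers are connected by single-layer ANNs (AnnLayer objects), so the network holds numLayers - 1 = 3 of them;
2) All layers use the Sigmoid activation function, which is why main passes the activator string "sss", one character per AnnLayer;
3) In forward propagation, the output of each layer is the input of the next layer; in back propagation, the weight matrices are updated layer by layer from back to front, starting at the output layer.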
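Before the full listing, the chaining described in point 3 can be sketched on its own. The class below is a minimal, self-contained illustration and not the AnnLayer used in this post: the class name ForwardChainSketch, the random weight initialization, and the choice to store the bias as an extra last row of each weight matrix are all assumptions made for the example.

```java
import java.util.Random;

public class ForwardChainSketch {
    /** Sigmoid activation, mapping any real number into (0, 1). */
    public static double sigmoid(double x) {
        return 1.0 / (1.0 + Math.exp(-x));
    }

    /**
     * Build one random weight matrix per pair of adjacent layers.
     * Assumption: each matrix has (inputs + 1) rows, the last row being the bias.
     */
    public static double[][][] randomWeights(int[] layerNodes, long seed) {
        Random random = new Random(seed);
        double[][][] weights = new double[layerNodes.length - 1][][];
        for (int l = 0; l < weights.length; l++) {
            weights[l] = new double[layerNodes[l] + 1][layerNodes[l + 1]];
            for (double[] row : weights[l]) {
                for (int j = 0; j < row.length; j++) {
                    row[j] = random.nextDouble() - 0.5;
                }
            }
        }
        return weights;
    }

    /** One layer: weighted sum plus bias, followed by sigmoid. */
    public static double[] forwardLayer(double[] input, double[][] weights) {
        int numOut = weights[0].length;
        double[] output = new double[numOut];
        for (int j = 0; j < numOut; j++) {
            double sum = weights[input.length][j]; // bias row
            for (int i = 0; i < input.length; i++) {
                sum += input[i] * weights[i][j];
            }
            output[j] = sigmoid(sum);
        }
        return output;
    }

    /** Chain the layers: the output of one layer is the input of the next. */
    public static double[] forward(double[] input, double[][][] weights) {
        double[] result = input;
        for (double[][] layerWeights : weights) {
            result = forwardLayer(result, layerWeights);
        }
        return result;
    }

    public static void main(String[] args) {
        int[] layerNodes = { 4, 8, 8, 3 };
        double[][][] weights = randomWeights(layerNodes, 0L);
        // An arbitrary four-feature input, chosen only for illustration.
        double[] output = forward(new double[] { 5.1, 3.5, 1.4, 0.2 }, weights);
        System.out.println("Weight matrices: " + weights.length + ", outputs: " + output.length);
    }
}
```

With layer sizes { 4, 8, 8, 3 } the sketch builds three weight matrices and returns three output values, each strictly between 0 and 1 because of the sigmoid.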
The code is as follows:
package JavaDay23;

import JavaDay21.GeneralAnn;

/**
 * Full ANN with a number of layers.
 *
 * @author Ke-Xiong Wang
 */
public class FullAnn extends GeneralAnn {
    /**
     * The layers.
     */
    AnnLayer[] layers;

    /**
     ********************
     * The first constructor.
     *
     * @param paraFilename
     *            The arff filename.
     * @param paraLayerNumNodes
     *            The number of nodes for each layer (may be different).
     * @param paraLearningRate
     *            Learning rate.
     * @param paraMobp
     *            Momentum coefficient.
     * @param paraActivators
     *            A string storing the activator of each layer, one character per layer.
     ********************
     */
    public FullAnn(String paraFilename, int[] paraLayerNumNodes, double paraLearningRate,
            double paraMobp, String paraActivators) {
        super(paraFilename, paraLayerNumNodes, paraLearningRate, paraMobp);

        // Initialize layers: one AnnLayer connects each pair of adjacent node layers.
        layers = new AnnLayer[numLayers - 1];
        for (int i = 0; i < layers.length; i++) {
            layers[i] = new AnnLayer(layerNumNodes[i], layerNumNodes[i + 1],
                    paraActivators.charAt(i), paraLearningRate, paraMobp);
        } // Of for i
    }// Of the first constructor

    /**
     ********************
     * Forward prediction.
     *
     * @param paraInput
     *            The input data of one instance.
     * @return The data at the output end.
     ********************
     */
    public double[] forward(double[] paraInput) {
        // The output of each layer is the input of the next one.
        double[] resultArray = paraInput;
        for (int i = 0; i < numLayers - 1; i++) {
            resultArray = layers[i].forward(resultArray);
        } // Of for i

        return resultArray;
    }// Of forward

    /**
     ********************
     * Back propagation.
     *
     * @param paraTarget
     *            For 3-class data, it is [0, 0, 1], [0, 1, 0] or [1, 0, 0].
     ********************
     */
    public void backPropagation(double[] paraTarget) {
        // Start from the errors at the output layer, then move backward layer by layer.
        double[] tempErrors = layers[numLayers - 2].getLastLayerErrors(paraTarget);
        for (int i = numLayers - 2; i >= 0; i--) {
            tempErrors = layers[i].backPropagation(tempErrors);
        } // Of for i
    }// Of backPropagation

    /**
     ********************
     * Show me.
     ********************
     */
    @Override
    public String toString() {
        String resultString = "I am a full ANN with " + numLayers + " layers";
        return resultString;
    }// Of toString

    /**
     ********************
     * Test the algorithm.
     ********************
     */
    public static void main(String[] args) {
        int[] tempLayerNodes = { 4, 8, 8, 3 };
        FullAnn tempNetwork = new FullAnn("D:/iris.arff", tempLayerNodes, 0.01,
                0.6, "sss");

        for (int round = 0; round < 5000; round++) {
            tempNetwork.train();
        } // Of for round

        double tempAccuracy = tempNetwork.test();
        System.out.println("The accuracy is: " + tempAccuracy);
        System.out.println("FullAnn ends.");
    }// Of main
}// Of class FullAnn
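The backPropagation comment in the listing above says that targets for 3-class data look like [0, 0, 1], [0, 1, 0] or [1, 0, 0]. Since GeneralAnn and AnnLayer are not shown in this post, the following is only a sketch of how such targets might be produced and compared with a network output. The class name OneHotSketch, the helper names, and the target-minus-output error convention are assumptions made for illustration and may differ from what getLastLayerErrors actually computes.

```java
import java.util.Arrays;

public class OneHotSketch {
    /** Encode a class index as a one-hot target, e.g. class 2 of 3 becomes [0, 0, 1]. */
    public static double[] oneHot(int classIndex, int numClasses) {
        double[] target = new double[numClasses];
        target[classIndex] = 1.0;
        return target;
    }

    /** Decode a network output as the index of its largest value. */
    public static int argmax(double[] output) {
        int best = 0;
        for (int i = 1; i < output.length; i++) {
            if (output[i] > output[best]) {
                best = i;
            }
        }
        return best;
    }

    /** Output-layer errors, assuming the convention error = target - output. */
    public static double[] outputErrors(double[] target, double[] output) {
        double[] errors = new double[target.length];
        for (int i = 0; i < target.length; i++) {
            errors[i] = target[i] - output[i];
        }
        return errors;
    }

    public static void main(String[] args) {
        double[] target = oneHot(2, 3);
        System.out.println(Arrays.toString(target)); // prints [0.0, 0.0, 1.0]

        // A made-up network output, used only to exercise the helpers.
        double[] output = { 0.1, 0.7, 0.2 };
        System.out.println("Predicted class: " + argmax(output)); // prints Predicted class: 1
        System.out.println(Arrays.toString(outputErrors(target, output)));
    }
}
```

Under these assumptions, an instance counts as correctly classified when the argmax of the network output equals the index of the 1 in its target vector.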
Output of running FullAnn: