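BP Neural Network
A BP (backpropagation) neural network is a multilayer feedforward neural network. Its two defining traits are forward prediction and backward propagation of the error.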
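Forward prediction:
Take the figure above as an example. The network has an input layer, a hidden layer, and an output layer, each with its own number of nodes. Every connection from an input-layer node to a hidden-layer node carries a weight, and so does every connection from a hidden-layer node to an output-layer node.
Each output-layer node represents one class. The input is assigned to the class whose output node produces the largest value.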
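The forward step described above can be sketched for a single layer as follows. This is a minimal illustration, not the post's implementation: the class `ForwardSketch`, its method names, the weight layout `weights[i][j]`, and the per-node `biases` array are all assumptions made for the example, and the sigmoid function is used as the activation.

```java
// Hypothetical sketch: all names and the weight layout are illustrative assumptions.
final class ForwardSketch {
    /** Logistic sigmoid, used here as the activation function. */
    static double sigmoid(double x) {
        return 1.0 / (1.0 + Math.exp(-x));
    }

    /**
     * Computes the outputs of one layer.
     * weights[i][j] is the weight from source node i to target node j;
     * biases[j] is an assumed per-node offset.
     */
    static double[] layerForward(double[] input, double[][] weights, double[] biases) {
        double[] output = new double[biases.length];
        for (int j = 0; j < output.length; j++) {
            double sum = biases[j];
            for (int i = 0; i < input.length; i++) {
                sum += input[i] * weights[i][j];
            }
            output[j] = sigmoid(sum);
        }
        return output;
    }

    /** Returns the index of the largest output, i.e. the predicted class. */
    static int predict(double[] outputs) {
        int best = 0;
        for (int k = 1; k < outputs.length; k++) {
            if (outputs[k] > outputs[best]) {
                best = k;
            }
        }
        return best;
    }
}
```

Calling `layerForward` once per layer, feeding each result into the next call, gives the output-layer values; `predict` then picks the class.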
Backpropagation:
Backpropagation sends the error back along the same connections and adjusts the weights between the nodes.
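One standard way to write this adjustment is the delta rule below. It is a sketch under stated assumptions, not a formula taken from the post: sigmoid activations, the squared error $E = \tfrac{1}{2}\sum_k (t_k - o_k)^2$, and the symbols $x_i$ (input), $h_j$ (hidden output), $o_k$ (network output), $t_k$ (target), $w_{ij}$ and $w_{jk}$ (weights), and $\eta$ (learning rate) are all introduced here for illustration.

```latex
\begin{align}
\delta_k &= (t_k - o_k)\, o_k (1 - o_k)
  && \text{error term of output node } k \\
\delta_j &= h_j (1 - h_j) \sum_k w_{jk}\, \delta_k
  && \text{error term of hidden node } j \\
w_{jk} &\leftarrow w_{jk} + \eta\, \delta_k\, h_j
  && \text{hidden-to-output weight} \\
w_{ij} &\leftarrow w_{ij} + \eta\, \delta_j\, x_i
  && \text{input-to-hidden weight}
\end{align}
```

The hidden error term reuses the output error terms, weighted by the same connections the signal travelled along in the forward pass, which is what "sending the error back" means in practice.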
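Activation function:
There is usually more than one hidden layer, so a nonlinear activation function is used. Without it, each layer would be a linear map, the composition of linear maps is itself linear, and a stack of hidden layers would be equivalent to a single layer.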
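The equivalence can be checked numerically. The sketch below, with hypothetical names (`LinearCollapseSketch`, `apply`, `multiply`) and made-up weight matrices, passes an input through two layers that have no activation function and compares the result with a single layer whose weight matrix is the product of the two.

```java
import java.util.Arrays;

// Hypothetical sketch: names and the example matrices are illustrative assumptions.
final class LinearCollapseSketch {
    /** Matrix-vector product; m has one row per target node. */
    static double[] apply(double[][] m, double[] v) {
        double[] result = new double[m.length];
        for (int r = 0; r < m.length; r++) {
            for (int c = 0; c < v.length; c++) {
                result[r] += m[r][c] * v[c];
            }
        }
        return result;
    }

    /** Matrix-matrix product a * b. */
    static double[][] multiply(double[][] a, double[][] b) {
        double[][] result = new double[a.length][b[0].length];
        for (int r = 0; r < a.length; r++) {
            for (int c = 0; c < b[0].length; c++) {
                for (int k = 0; k < b.length; k++) {
                    result[r][c] += a[r][k] * b[k][c];
                }
            }
        }
        return result;
    }

    public static void main(String[] args) {
        double[][] w1 = {{1, 2}, {3, 4}, {5, 6}}; // 2 inputs -> 3 hidden nodes
        double[][] w2 = {{1, 0, -1}, {2, 1, 0}};  // 3 hidden nodes -> 2 outputs
        double[] x = {1, -1};

        // Two linear layers applied one after the other.
        System.out.println(Arrays.toString(apply(w2, apply(w1, x))));
        // One linear layer with the combined weight matrix.
        System.out.println(Arrays.toString(apply(multiply(w2, w1), x)));
    }
}
```

Both lines print the same vector. Inserting a nonlinear function between the two `apply` calls breaks this identity, which is why the extra layer then adds expressive power.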
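A commonly used activation function is the sigmoid function:
$$\sigma(x) = \frac{1}{1 + e^{-x}}$$
The sigmoid function maps any input x into the interval (0, 1). The larger x is, the closer the output is to 1; the smaller x is, the closer the output is to 0.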
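A minimal Java sketch of the sigmoid function follows. The class and method names are hypothetical. The second method is an addition not mentioned in the text: it uses the identity σ'(x) = σ(x)(1 − σ(x)), which lets the derivative needed during backpropagation be computed from a node's output alone.

```java
// Hypothetical sketch: class and method names are illustrative assumptions.
final class SigmoidSketch {
    /** Maps any real x into the open interval (0, 1). */
    static double sigmoid(double x) {
        return 1.0 / (1.0 + Math.exp(-x));
    }

    /** Derivative of the sigmoid, expressed through its output y = sigmoid(x). */
    static double derivativeFromOutput(double y) {
        return y * (1.0 - y);
    }

    public static void main(String[] args) {
        for (double x : new double[] {-10, -1, 0, 1, 10}) {
            System.out.println(x + " -> " + sigmoid(x));
        }
    }
}
```

The output is exactly 0.5 at x = 0, and the derivative is largest there (0.25), shrinking toward 0 as the output saturates near 0 or 1.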
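Main procedure:
- Initialize the weights.
- Forward prediction: compute the output values of the output-layer nodes.
- Backpropagation: compute the error at each node and adjust the weights.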
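The three steps above can be sketched end to end with a self-contained network that has one hidden layer. This is an illustration under assumptions, not the post's implementation: the class `TinyBpSketch`, the weight layout with the bias stored in the last row, the initial weight range, the fixed random seed, and the plain gradient-descent update on the squared error are all choices made for the example.

```java
import java.util.Random;

// Hypothetical sketch: names, weight layout, and the update rule are illustrative assumptions.
final class TinyBpSketch {
    final double[][] w1; // [numInput + 1][numHidden]; the last row holds the biases
    final double[][] w2; // [numHidden + 1][numOutput]; the last row holds the biases
    final double learningRate;
    double[] input, hidden, output;

    // Step 1: initialize the weights with small random values.
    TinyBpSketch(int numInput, int numHidden, int numOutput, double learningRate, long seed) {
        Random random = new Random(seed);
        w1 = randomMatrix(numInput + 1, numHidden, random);
        w2 = randomMatrix(numHidden + 1, numOutput, random);
        this.learningRate = learningRate;
    }

    static double[][] randomMatrix(int rows, int cols, Random random) {
        double[][] m = new double[rows][cols];
        for (double[] row : m) {
            for (int j = 0; j < cols; j++) {
                row[j] = random.nextDouble() - 0.5;
            }
        }
        return m;
    }

    static double[] layer(double[] in, double[][] w) {
        double[] out = new double[w[0].length];
        for (int j = 0; j < out.length; j++) {
            double sum = w[in.length][j]; // bias
            for (int i = 0; i < in.length; i++) {
                sum += in[i] * w[i][j];
            }
            out[j] = 1.0 / (1.0 + Math.exp(-sum)); // sigmoid
        }
        return out;
    }

    // Step 2: forward prediction.
    double[] forward(double[] x) {
        input = x;
        hidden = layer(input, w1);
        output = layer(hidden, w2);
        return output;
    }

    // Step 3: compute the error terms, then adjust the weights.
    void backPropagation(double[] target) {
        double[] outputDelta = new double[output.length];
        for (int k = 0; k < output.length; k++) {
            outputDelta[k] = (target[k] - output[k]) * output[k] * (1 - output[k]);
        }
        // Hidden error terms use the weights as they were during the forward pass.
        double[] hiddenDelta = new double[hidden.length];
        for (int j = 0; j < hidden.length; j++) {
            double sum = 0;
            for (int k = 0; k < output.length; k++) {
                sum += w2[j][k] * outputDelta[k];
            }
            hiddenDelta[j] = sum * hidden[j] * (1 - hidden[j]);
        }
        update(w2, hidden, outputDelta);
        update(w1, input, hiddenDelta);
    }

    void update(double[][] w, double[] in, double[] delta) {
        for (int j = 0; j < delta.length; j++) {
            for (int i = 0; i < in.length; i++) {
                w[i][j] += learningRate * delta[j] * in[i];
            }
            w[in.length][j] += learningRate * delta[j]; // bias
        }
    }

    static double squaredError(double[] target, double[] out) {
        double sum = 0;
        for (int k = 0; k < out.length; k++) {
            sum += 0.5 * (target[k] - out[k]) * (target[k] - out[k]);
        }
        return sum;
    }
}
```

Repeating `forward` followed by `backPropagation` on a sample drives the squared error on that sample down; a full training run simply loops this pair over every instance in the dataset.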
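package knn5;

import java.io.FileReader;
import java.util.Arrays;
import java.util.Random;

import weka.core.Instances;

// Abstract base class for a layered network; subclasses supply forward() and backPropagation().
public abstract class GeneralAnn {
    // The whole dataset, loaded from the given file.
    Instances dataset;
    // Number of layers, counting the input and output layers.
    int numLayers;
    // Number of nodes in each layer.
    int[] layerNumNodes;
    // Step size for weight updates.
    public double learningRate;
    // Momentum coefficient, going by its name; its use is left to subclasses.
    public double mobp;
    // Random number generator, available to subclasses.
    Random random = new Random();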
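    /**
     * Reads the dataset and records the network shape.
     *
     * @param paraFilename      The data filename.
     * @param paraLayerNumNodes Nodes per layer. The first and last entries of this array are
     *                          overwritten with the attribute and class counts of the dataset.
     * @param paraLearningRate  The learning rate.
     * @param paraMobp          The momentum coefficient.
     */
    public GeneralAnn(String paraFilename, int[] paraLayerNumNodes, double paraLearningRate,
            double paraMobp) {
        // try-with-resources closes the reader even when parsing fails.
        try (FileReader tempReader = new FileReader(paraFilename)) {
            dataset = new Instances(tempReader);
            // The last attribute is the decision class.
            dataset.setClassIndex(dataset.numAttributes() - 1);
        } catch (Exception ee) {
            System.out.println("Error occurred while trying to read '" + paraFilename
                    + "' in GeneralAnn constructor.\r\n" + ee);
            // A nonzero exit status signals the failure to the caller.
            System.exit(1);
        } // Of try

        layerNumNodes = paraLayerNumNodes;
        numLayers = layerNumNodes.length;
        // The input and output layer sizes are dictated by the data.
        layerNumNodes[0] = dataset.numAttributes() - 1;
        layerNumNodes[numLayers - 1] = dataset.numClasses();
        learningRate = paraLearningRate;
        mobp = paraMobp;
    }// Of the first constructor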
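    // Forward pass: returns the output-layer values for one input vector.
    public abstract double[] forward(double[] paraInput);
    // Backward pass: adjusts the weights toward the target; train() calls it right after forward().
    public abstract void backPropagation(double[] paraTarget);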
    /**
     * Makes one pass over the dataset, updating the weights after every instance.
     */
    public void train() {
        double[] tempInput = new double[dataset.numAttributes() - 1];
        double[] tempTarget = new double[dataset.numClasses()];
        for (int i = 0; i < dataset.numInstances(); i++) {
            // Copy the condition attributes into the input vector.
            for (int j = 0; j < tempInput.length; j++) {
                tempInput[j] = dataset.instance(i).value(j);
            } // Of for j

            // One-hot encode the class label as the target vector.
            Arrays.fill(tempTarget, 0);
            tempTarget[(int) dataset.instance(i).classValue()] = 1;

            forward(tempInput);
            backPropagation(tempTarget);
        } // Of for i
    }// Of train
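    /**
     * Returns the index of the largest value; on ties, the first such index wins.
     */
    public static int argmax(double[] paraArray) {
        int resultIndex = -1;
        // Start below every finite value so that very negative entries are still found.
        double tempMax = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < paraArray.length; i++) {
            if (tempMax < paraArray[i]) {
                tempMax = paraArray[i];
                resultIndex = i;
            } // Of if
        } // Of for i
        return resultIndex;
    }// Of argmax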
    /**
     * Tests on the same dataset that train() uses, so the result is the training accuracy.
     *
     * @return The fraction of correctly classified instances.
     */
    public double test() {
        double[] tempInput = new double[dataset.numAttributes() - 1];
        double tempNumCorrect = 0;
        double[] tempPrediction;
        int tempPredictedClass = -1;

        for (int i = 0; i < dataset.numInstances(); i++) {
            // Copy the condition attributes into the input vector.
            for (int j = 0; j < tempInput.length; j++) {
                tempInput[j] = dataset.instance(i).value(j);
            } // Of for j

            // The predicted class is the output node with the largest value.
            tempPrediction = forward(tempInput);
            tempPredictedClass = argmax(tempPrediction);
            if (tempPredictedClass == (int) dataset.instance(i).classValue()) {
                tempNumCorrect++;
            } // Of if
        } // Of for i

        System.out.println("Correct: " + tempNumCorrect + " out of " + dataset.numInstances());
        return tempNumCorrect / dataset.numInstances();
    }// Of test
}// Of class GeneralAnn
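`GeneralAnn` stores `mobp` but leaves its use to subclasses. One common way a subclass could use it, assumed here rather than taken from the code above, is classical momentum: each weight change is the learning-rate-scaled step plus a fraction `mobp` of the previous change. The sketch below shows this for a single weight; the class `MomentumUpdateSketch` and the parameter `gradientStep` (a node's error term times the output of the source node) are hypothetical names.

```java
// Hypothetical sketch: one assumed way to combine learningRate and mobp for a single weight.
final class MomentumUpdateSketch {
    double weight;
    double previousDelta = 0; // the weight change applied in the previous update
    final double learningRate;
    final double mobp;

    MomentumUpdateSketch(double weight, double learningRate, double mobp) {
        this.weight = weight;
        this.learningRate = learningRate;
        this.mobp = mobp;
    }

    /** Applies one update and returns the new weight. */
    double update(double gradientStep) {
        // Keep a fraction of the last change, then add the new step.
        previousDelta = mobp * previousDelta + learningRate * gradientStep;
        weight += previousDelta;
        return weight;
    }
}
```

With `mobp` set to 0 this reduces to the plain update; larger values let consecutive steps in the same direction build on each other, and the weight keeps moving for a few updates even after the step itself drops to zero.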