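This is my first blog post. The code is rather clumsy, but I wrote most of it myself; think of it as the starting point of my journey into machine learning.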
Part 1. Understanding how the BP algorithm works
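For the theory behind the BP (backpropagation) algorithm, see the article below; many thanks to its author for sharing.
Implementing an ANN in Java: BP code reference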
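Main steps of the BP algorithm:
① Every node in every layer has its own input weights. Write the node's n inputs as X and its n weights as W. The inner product e = X·W of the two column vectors is passed through an activation function, and the result is propagated forward, i.e. it becomes an input to the next layer.
② The output layer produces the network's result, which is compared with the expected output from the training set to obtain the error Δ. This error is propagated backward, each node's weights are updated according to the gradient of the error with respect to those weights, and the next training round can begin.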
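The two steps above can be illustrated on a single sigmoid neuron. This is a standalone sketch, not the author's code: the class name `SingleNeuronStep`, the sample inputs, the target 0.9 and the learning rate 0.5 are made up for illustration. It assumes the same activation f(e) = 1/(1 + exp(-a·e)) with a = 1 and the same update rule as `WeightUpload` in Part 3, with the error taken as target minus output (the output-layer case).

```java
import java.util.Arrays;

public class SingleNeuronStep {
    // Activation coefficient; BPNode in Part 3 also uses a = 1.
    static final double A = 1.0;

    // Step 1 (forward): e = X·W, then f(e) = 1 / (1 + exp(-a*e)).
    static double forward(double[] x, double[] w) {
        double e = 0.0;
        for (int i = 0; i < x.length; i++) {
            e += x[i] * w[i];
        }
        return 1.0 / (1.0 + Math.exp(-A * e));
    }

    // Step 2 (backward, output layer): delta = target - output,
    // then w_i += eta * delta * x_i * a * f * (1 - f).
    // Returns new weights and leaves the input array unchanged.
    static double[] update(double[] x, double[] w, double target, double eta) {
        double f = forward(x, w);
        double delta = target - f;
        double[] next = Arrays.copyOf(w, w.length);
        for (int i = 0; i < w.length; i++) {
            next[i] += eta * delta * x[i] * A * f * (1 - f);
        }
        return next;
    }

    public static void main(String[] args) {
        double[] x = {0.5, 0.2};
        double[] w = {0.1, 0.4};
        double target = 0.9;
        double before = forward(x, w);
        double after = forward(x, update(x, w, target, 0.5));
        System.out.printf("output before: %.4f, after: %.4f%n", before, after);
    }
}
```

Because the error is positive and both inputs are positive, both weights grow and the output moves toward the target. Hidden-layer nodes use the same kind of update, but their error comes from the layer above (what `ErrDown`/`deltaDown` compute in Part 2) rather than directly from the target.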
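Part 2. Approach to implementing BP in Java
Following the structure of a BP network, I wrote three data classes and split the computation among their member functions, plus a training class and a utility class: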
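① Node class BPNode
Member variables: the activation-function coefficient a; the weights applied to the node's inputs, weight
Member functions: compute the node's output from its inputs and weights, FunctionOfNode
update the node's weights from the back-propagated error, WeightUpload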
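② Layer class BPLayer
Member variables: node list AllNode
node error list Layerdelta
node input and output lists Layerinput, LayerOutput
Member functions: given the error values, use the node weights to compute each node's error contribution to the nodes of the layer below (closer to the input), ErrDown
compute the error this layer passes back to the layer below, deltaDown
compute the layer's output, computeOut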
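③ Network class BPNet
Member variables: the network's list of layers, allLayer
expected network output, ExpectOut
network output error, errOfnet
learning rate, eta
Member functions:
compute the output error, ErrOfouty
normalize the data set and record its maximum and minimum values, SetNormalize
build the network described by the input parameters, loadBPnet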
④ Training class
Member functions:
train the network using the functions of the three data classes above, settrain
⑤ Utility class
normalization functions normalize, normalize2D
inverse normalization Antinormalize, Antinormalize2D
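The utility functions are not shown in this section. Since SetNormalize is described as recording the maximum and minimum of the data set, min-max scaling is a plausible reading; the following is a minimal sketch under that assumption, with a hypothetical class `MinMaxScaling` standing in for normalize/Antinormalize (one-dimensional case only).

```java
import java.util.Arrays;

public class MinMaxScaling {
    // Min and max recorded from the training data (assumed to be what SetNormalize stores).
    final double min;
    final double max;

    MinMaxScaling(double[] data) {
        double lo = Double.POSITIVE_INFINITY;
        double hi = Double.NEGATIVE_INFINITY;
        for (double v : data) {
            lo = Math.min(lo, v);
            hi = Math.max(hi, v);
        }
        this.min = lo;
        this.max = hi;
    }

    // Maps each value into [0, 1]; a constant column maps to 0 to avoid dividing by zero.
    double[] normalize(double[] data) {
        double range = max - min;
        double[] out = new double[data.length];
        for (int i = 0; i < data.length; i++) {
            out[i] = range == 0 ? 0.0 : (data[i] - min) / range;
        }
        return out;
    }

    // Inverse mapping: turns scaled values (e.g. network outputs) back into original units.
    double[] antinormalize(double[] scaled) {
        double[] out = new double[scaled.length];
        for (int i = 0; i < scaled.length; i++) {
            out[i] = scaled[i] * (max - min) + min;
        }
        return out;
    }

    public static void main(String[] args) {
        double[] raw = {2.0, 4.0, 10.0};
        MinMaxScaling s = new MinMaxScaling(raw);
        double[] n = s.normalize(raw);
        System.out.println(Arrays.toString(n));                  // [0.0, 0.25, 1.0]
        System.out.println(Arrays.toString(s.antinormalize(n))); // [2.0, 4.0, 10.0]
    }
}
```

Scaling into [0, 1] matters for this network because the sigmoid output lies strictly between 0 and 1, so targets outside that range could never be reached; predictions are mapped back with the inverse function.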
Part 3. Code
① Node class BPNode:
```java
package BPNetMember;

import java.util.ArrayList;
import ToolFunction.DataProcessing;

public class BPNode {
    /******** Member variables ********/
    // Coefficient a of the activation function f(e) = 1 / (1 + exp(-a*e))
    private final static double a = 1;
    // One weight per input of this node
    private ArrayList<Double> weight = new ArrayList<Double>();

    /******** Member functions ********/
    // 1. Constructors
    public BPNode(int n) {
        // n is the number of inputs; weights start as random values in [0, 1)
        for (int i = 0; i < n; i++) {
            this.weight.add(Math.random());
        }
    }

    public BPNode() {
    }

    // 2. Getters and setters
    public ArrayList<Double> getWeight() {
        return weight;
    }

    public double[] getWeighttoA() {
        double[] tempW = new double[weight.size()];
        for (int i = 0; i < weight.size(); i++) {
            tempW[i] = weight.get(i);
        }
        return tempW;
    }

    public void setWeight(ArrayList<Double> weight) {
        this.weight = weight;
    }

    public void setWeight(double[] weight) {
        // Replace the existing weights instead of appending to them
        ArrayList<Double> newWeight = new ArrayList<Double>();
        for (int i = 0; i < weight.length; i++) {
            newWeight.add(weight[i]);
        }
        this.weight = newWeight;
    }

    // 3. Functional methods
    @Override
    public String toString() {
        return "BPNode [weight=" + weight + "]";
    }

    /******* Compute the node output for one input vector *******/
    /******* input: the node's (normalized) inputs *******/
    public static double FunctionOfNode(BPNode node, ArrayList<Double> input) {
        // input must have the same length as the node's weight list
        double[] data = DataProcessing.toarray(input);
        return FunctionOfNode(node, data);
    }

    public static double FunctionOfNode(BPNode node, double[] input) {
        // input must have the same length as the node's weight list
        double temp = 0;
        for (int i = 0; i < input.length; i++) {
            temp += input[i] * node.getWeight().get(i);
        }
        temp = 1 / (1 + Math.exp(-a * temp)); // activation: f(e) = 1 / (1 + exp(-a*e))
        return temp;
    }

    /******* Update the weights from the back-propagated error *******/
    /** delta: this node's error; input: inputs of the last forward pass; fe: node output of the last forward pass **/
    public static ArrayList<Double> WeightUpload(BPNode node, double delta, ArrayList<Double> input, double fe) {
        double[] data = DataProcessing.toarray(input);
        return WeightUpload(node, delta, data, fe);
    }

    public static ArrayList<Double> WeightUpload(BPNode node, double delta, double[] input, double fe) {
        ArrayList<Double> WeightToChange = new ArrayList<Double>();
        ArrayList<Double> weight = node.getWeight();
        for (int i = 0; i < weight.size(); i++) {
            // w_i += eta * delta * x_i * f'(e), where f'(e) = a * f * (1 - f)
            double wi = weight.get(i) + BPNet.eta * delta * input[i] * a * fe * (1 - fe);
            WeightToChange.add(wi);
        }
        node.setWeight(WeightToChange);
        return WeightToChange;
    }
}
```
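The factor `a*fe*(1-fe)` in `WeightUpload` is the derivative of the activation f(e) = 1/(1 + exp(-a·e)), so the update w_i += eta·delta·x_i·f'(e) is a gradient-descent step on the squared error ½(target - f)² when delta is taken as target - f. A quick way to convince yourself is a finite-difference check. The class below is a standalone sketch; the name `SigmoidDerivativeCheck` and the sample points are illustrative and not part of the project.

```java
public class SigmoidDerivativeCheck {
    static final double A = 1.0; // same coefficient as BPNode.a

    // Activation used by FunctionOfNode.
    static double f(double e) {
        return 1.0 / (1.0 + Math.exp(-A * e));
    }

    // Analytic derivative used inside WeightUpload: a * f * (1 - f).
    static double analytic(double e) {
        double fe = f(e);
        return A * fe * (1 - fe);
    }

    // Central finite-difference approximation of df/de.
    static double numeric(double e, double h) {
        return (f(e + h) - f(e - h)) / (2 * h);
    }

    public static void main(String[] args) {
        for (double e : new double[] {-2.0, 0.0, 0.13, 3.0}) {
            System.out.printf("e=%5.2f analytic=%.6f numeric=%.6f%n",
                    e, analytic(e), numeric(e, 1e-5));
        }
    }
}
```

The two columns agree to many decimal places, and at e = 0 the derivative is exactly a/4 = 0.25. The same check is a cheap way to catch sign or factor mistakes if the activation function is ever changed.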
② Layer class BPLayer:
package BPNetMember;
import java.util.ArrayList;
import java.util.HashMap;
import ToolFunction.DataProcessing;
public class BPLayer {
/******** Member variables ********/
private ArrayList<BPNode> AllNode = new ArrayList<BPNode>();
private ArrayList<Double> Layerdelta = new ArrayList<Double>();
private ArrayList<Double> Layerinput = new ArrayList<Double>();
private ArrayList<Double> LayerOutput = new ArrayList<Double>(); // values must be stored in the same order as the nodes in AllNode
/******** Member functions ********/
// 1. Constructor
public BPLayer() {
super();
}
// 2. Getters and setters
public ArrayList<BPNode> getAllNode() {
return AllNode;
}
public void setAllNode(ArrayList<BPNode> allNode) {
AllNode = allNode;
}
public ArrayList<Double> getLayerOutput() {
return LayerOutput;
}
public void setLayyerOutput(ArrayList<Double> layerOutput) {
LayerOutput = layerOutput;
}
public void addAllNode(BPNode node) {
AllNode.add(node);
}
public ArrayList<Double> getLayerdelta() {
return Layerdelta;
}
public void setLayerdelta(ArrayList<Double> layerdelta) {
Layerdelta = layerdelta;
}
public ArrayList<Double> getLayerinput() {
return Layerinput;
}
public void setLayerinput(ArrayList<Double> layerinput) {
Layerinput = layerinput;
}
public void setLayerinput(double