package classificationModel;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileWriter;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Set;
import tool.DataDealing;
import tool.Function;
public class BP_Deep {
    // Fully-connected feed-forward network trained with backpropagation.
    // Layer 0 is the input layer, the last layer is the output layer,
    // everything in between is hidden.
    private double[][] netNode;      // [layer][node] — activation of each node
    private double[][] nodeErrors;   // [layer][node] — back-propagated delta of each node
    private double[][][] nodeWeight; // [layer][fromNode][toNode]; extra row [nodeCount] holds the bias
    private double[][][] nodeWDelta; // [layer][fromNode][toNode] — previous adjustment (momentum term)
    private double mobp;             // momentum coefficient applied to the previous delta
    private double rate;             // learning rate

    /**
     * Builds a network with the given layer sizes, momentum and learning rate.
     * Weights are initialized randomly; the bias row of each weight matrix is
     * set to 1.0 (see {@link #init(int[], double, double)}).
     *
     * @param nodeNumPerLayer number of nodes in each layer, input first
     * @param m momentum coefficient
     * @param r learning rate
     */
    public BP_Deep(int[] nodeNumPerLayer, double m, double r) {
        // Previously this constructor duplicated init() line for line; delegate instead.
        init(nodeNumPerLayer, m, r);
    }

    /**
     * Builds a network sized from the training set, converts the raw rows into
     * numeric form, and trains for {@code n} full passes.
     *
     * @param trainingSet raw rows; the last column of each row is the label
     * @param data output buffer filled with the numeric features (one row per sample)
     * @param target output buffer filled with the numeric labels (one value per sample)
     * @param n number of training iterations over the whole set
     * @param b flag forwarded to DataDealing.dataTransfer (semantics defined there —
     *          presumably selects a label-encoding mode; confirm against DataDealing)
     * @throws IOException declared for interface compatibility with callers
     */
    public BP_Deep(ArrayList<ArrayList<String>> trainingSet, double[][] data,
                   double[][] target, int n, boolean b) throws IOException {
        final int col = trainingSet.get(0).size() - 1; // feature count (last column is the label)
        final int hidden = 1;                          // number of hidden layers
        int[] layers = new int[hidden + 2];
        layers[0] = col;          // input layer width
        layers[hidden + 1] = 1;   // single output node
        // Hidden layer widths interpolate (integer math) between col and 2.
        for (int i = 1; i < hidden + 1; ++i)
            layers[i] = col - i * (col - 2) / (hidden + 1);
        DataDealing.dataTransfer(trainingSet, data, target, b);
        init(layers, 0.15, 0.8);
        trainModel(data, target, n); // e.g. 10000 iterations
    }

    /**
     * Allocates all matrices and initializes the weights: random values for the
     * regular connections, 1.0 for the bias row of every layer.
     *
     * @param nodeNumPerLayer number of nodes in each layer, input first
     * @param m momentum coefficient
     * @param r learning rate
     */
    public void init(int[] nodeNumPerLayer, double m, double r) {
        this.mobp = m;
        this.rate = r;
        int layers = nodeNumPerLayer.length;
        netNode = new double[layers][];
        nodeErrors = new double[layers][];
        nodeWeight = new double[layers][][];
        nodeWDelta = new double[layers][][];
        for (int i = 0; i < layers; ++i) {
            netNode[i] = new double[nodeNumPerLayer[i]];
            nodeErrors[i] = new double[nodeNumPerLayer[i]];
            if (i + 1 < layers) { // the output layer owns no outgoing weights
                // +1 row: index [nodeNumPerLayer[i]] stores the bias toward layer i+1.
                nodeWeight[i] = new double[nodeNumPerLayer[i] + 1][nodeNumPerLayer[i + 1]];
                nodeWDelta[i] = new double[nodeNumPerLayer[i] + 1][nodeNumPerLayer[i + 1]];
                for (int j = 0; j < nodeNumPerLayer[i]; ++j)
                    for (int k = 0; k < nodeNumPerLayer[i + 1]; ++k)
                        nodeWeight[i][j][k] = Math.random();
                for (int k = 0; k < nodeNumPerLayer[i + 1]; ++k)
                    nodeWeight[i][nodeNumPerLayer[i]][k] = 1.0; // bias starts at 1.0
            }
        }
    }

    /**
     * Feeds {@code inp} forward through the network layer by layer.
     *
     * @param inp input vector; must have at least as many entries as the input layer
     * @return the output layer's activations (a live reference into internal state)
     */
    public double[] computeOutput(double[] inp) {
        for (int j = 0; j < netNode[0].length; ++j) netNode[0][j] = inp[j];
        for (int i = 1; i < netNode.length; ++i)
            for (int j = 0; j < netNode[i].length; ++j) {
                double z = nodeWeight[i - 1][netNode[i - 1].length][j]; // start from the bias
                for (int k = 0; k < netNode[i - 1].length; ++k)
                    z += nodeWeight[i - 1][k][j] * netNode[i - 1][k];
                netNode[i][j] = Function.sigmoid(z);
            }
        return netNode[netNode.length - 1];
    }

    /**
     * One backpropagation step: compute every layer's error against
     * {@code target}, then apply the weight updates.
     */
    public void backPropagation(double[] target) {
        backComputeErrors(target);
        updateWeight();
    }

    /** Computes the delta of every node, output layer first, then backward. */
    private void backComputeErrors(double[] target) {
        int i = netNode.length - 1; // current layer: start at the output layer
        // Output-layer delta: sigmoid'(o) * (target - o), with sigmoid'(o) = o*(1-o).
        for (int j = 0; j < nodeErrors[i].length; ++j)
            nodeErrors[i][j] = netNode[i][j] * (1.0 - netNode[i][j]) * (target[j] - netNode[i][j]);
        while (i-- > 0) {
            // Hidden/input-layer delta: weighted sum of the next layer's deltas.
            for (int j = 0; j < nodeErrors[i].length; ++j) {
                double z = 0.0;
                for (int k = 0; k < nodeErrors[i + 1].length; ++k)
                    z += nodeWeight[i][j][k] * nodeErrors[i + 1][k];
                nodeErrors[i][j] = z * netNode[i][j] * (1.0 - netNode[i][j]);
            }
        }
    }

    /** Applies momentum-smoothed weight (and bias) updates from the stored deltas. */
    private void updateWeight() {
        for (int i = 0; i < netNode.length - 1; ++i)
            for (int j = 0; j < nodeErrors[i].length; ++j)
                for (int k = 0; k < nodeErrors[i + 1].length; ++k) {
                    nodeWDelta[i][j][k] = mobp * nodeWDelta[i][j][k]
                            + rate * nodeErrors[i + 1][k] * netNode[i][j];
                    nodeWeight[i][j][k] += nodeWDelta[i][j][k];
                    if (j == nodeErrors[i].length - 1) { // last source node: also update the bias row
                        nodeWDelta[i][j + 1][k] = mobp * nodeWDelta[i][j + 1][k]
                                + rate * nodeErrors[i + 1][k];
                        nodeWeight[i][j + 1][k] += nodeWDelta[i][j + 1][k];
                    }
                }
    }

    /** Trains on a single sample: forward pass, then backpropagation. */
    public void train(double[] inp, double[] target) {
        computeOutput(inp);
        backPropagation(target);
    }

    /**
     * Trains on the whole data set for {@code max} full passes.
     *
     * @param inp one input vector per row
     * @param tar one target vector per row, aligned with {@code inp}
     * @param max number of passes over the data
     */
    public void trainModel(double[][] inp, double[][] tar, int max) {
        for (int n = 1; n <= max; n++)
            for (int i = 0; i < inp.length; i++)
                train(inp[i], tar[i]);
    }

    /**
     * Saves the trained parameters to the file {@code "_" + path}:
     * line 1 = layer sizes (comma separated), line 2 = momentum, line 3 = rate,
     * then one comma-separated line of weights per layer.
     *
     * @throws IOException if the file cannot be created or written
     */
    public void stayParameter(String path) throws IOException {
        File file = new File("_" + path);
        if (file.exists()) file.delete();
        file.createNewFile();
        // try-with-resources: the writer is closed even if a write throws.
        try (FileWriter fileWriter = new FileWriter(file, true)) {
            StringBuilder sb = new StringBuilder();
            for (int i = 0; i < netNode.length - 1; ++i)
                sb.append(netNode[i].length).append(',');
            sb.append(netNode[netNode.length - 1].length).append('\n');
            fileWriter.write(sb.toString());
            fileWriter.write(String.valueOf(mobp) + "\n");
            fileWriter.write(String.valueOf(rate) + "\n");
            // One line of weights per layer (the output layer's entry is null — skip it).
            for (double[][] layer : nodeWeight) {
                if (layer == null) continue;
                sb.setLength(0);
                for (int j = 0; j < layer.length; ++j)
                    for (int k = 0; k < layer[j].length; ++k) {
                        sb.append(layer[j][k]);
                        if (j != layer.length - 1 || k != layer[j].length - 1) sb.append(',');
                    }
                fileWriter.write(sb.toString().trim() + "\n");
            }
        }
    }

    /**
     * Restores parameters previously written by {@link #stayParameter(String)}
     * from the file {@code "_" + path}.
     *
     * @return true if the file existed and was loaded; false if it is missing
     * @throws IOException if reading fails
     */
    public boolean readParameter(String path) throws IOException {
        File file = new File("_" + path);
        if (!file.exists() || !file.isFile()) {
            System.out.println("参数读取失败,执行默认操作...");
            return false;
        }
        // try-with-resources: the reader is closed even if parsing throws.
        try (BufferedReader reader = new BufferedReader(
                new InputStreamReader(new FileInputStream(file)))) {
            String str = reader.readLine(); // line 1: layer sizes
            String[] token = str.split(",");
            int[] nodeNumPerLayer = new int[token.length];
            int i = 0;
            for (String s : token) nodeNumPerLayer[i++] = Integer.parseInt(s);
            str = reader.readLine(); // line 2: momentum
            mobp = Double.parseDouble(str);
            str = reader.readLine(); // line 3: learning rate
            rate = Double.parseDouble(str);
            // Allocate the matrices, then overwrite the random weights with the saved ones.
            init(nodeNumPerLayer, mobp, rate);
            for (i = 0; i < nodeNumPerLayer.length - 1; ++i) {
                str = reader.readLine();
                String[] tokenW = str.split(",");
                int p = 0;
                for (int j = 0; j < nodeWeight[i].length; ++j)
                    for (int k = 0; k < nodeWeight[i][j].length; ++k)
                        nodeWeight[i][j][k] = Double.parseDouble(tokenW[p++]);
            }
        }
        return true;
    }

    /**
     * Evaluates the model: each prediction is snapped to the nearest label value
     * seen in {@code tar} and compared with the true label.
     *
     * @return fraction of correctly classified samples in [0, 1]
     */
    public double reportModel(double[][] inp, double[][] tar) {
        Set<Double> set = new HashSet<Double>(); // distinct label values
        for (double[] xx : tar) set.add(xx[0]);
        int count = 0;
        for (int i = 0; i < inp.length; ++i) {
            double[] result = computeOutput(inp[i]);
            if (selectPredict(result[0], set) == tar[i][0]) ++count;
            System.out.println(tar[i][0] + "\t" + result[0]);
        }
        return count / (double) inp.length;
    }

    /** Returns the label in {@code set} closest to {@code val}. */
    private double selectPredict(double val, Set<Double> set) {
        double err = Double.MAX_VALUE;
        double result = 0.0;
        for (double x : set)
            if (Math.abs(val - x) < err) {
                err = Math.abs(val - x);
                result = x;
            }
        return result;
    }
}
package classificationModel;
import java.io.IOException;
import java.util.ArrayList;
import tool.ReadData;
public class BPTest {
    /**
     * Driver: loads a data set, trains a BP network on it for 10000 passes,
     * prints the accuracy, and persists the learned parameters.
     */
    public static void main(String[] args) throws IOException {
        // Available data sets: divorce.txt, AutismAdultDataPlus.txt,
        // StudentAcademicsPerformance.txt
        final String path = "divorce.txt";
        final boolean bb = false;

        ArrayList<ArrayList<String>> trainingSet = ReadData.readDataFileList(path);
        final int rows = trainingSet.size();
        final int cols = trainingSet.get(0).size() - 1; // last column is the label

        double[][] inp = new double[rows][cols];
        double[][] target = new double[rows][1];

        BP_Deep bpDeep = new BP_Deep(trainingSet, inp, target, 10000, bb);
        System.out.println("模型准确率:" + bpDeep.reportModel(inp, target));
        bpDeep.stayParameter(path);
    }
}
// 参考链接 (References):
//   深度神经网络总结 (Summary of deep neural networks)
//   BP神经网络的Java实现 (A Java implementation of the BP neural network)