什么是BP神经网络?BP网络是一种按照误差逆向传播算法训练的多层前馈神经网络。基本的BP神经网络包括信号的前向传播和误差的反向传播两个过程。这里我们具体研究的是简单的三层BP网络,这三层分别是输入层、隐含层和输出层。如下图,输入层有 n 个输入 X;隐含层的神经元个数由自己定义,为 $S_1$ 个;输出层的神经元个数等于目标输出的个数,为 $S_2$ 个。
由上图可知,输入层到隐含层的权值矩阵为 $S_1\times n$,偏置向量为 $S_1\times 1$;隐含层到输出层的权值矩阵为 $S_2\times S_1$,偏置向量为 $S_2\times 1$。
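首先定义误差函数。对单个样本,误差为
$$E=\frac{1}{2}\sum_{k=1}^{S_2}\left(y_k-o_k\right)^2,$$
其中 $y_k$ 为目标输出,$o_k$ 为网络的实际输出。
这是单个样本的误差;一组 $m$ 个样本的误差取所有样本误差的平均值即可,即 $E=\frac{1}{2m}\sum_{p=1}^{m}\sum_{k=1}^{S_2}\left(y_k^{(p)}-o_k^{(p)}\right)^2$。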
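与感知器神经元相比,BP网络在修正权值时引入了误差效能 $\delta$。记 $x_i$ 为输入,$a_j$ 为隐含层输出,$o_k$ 为输出层输出,$y_k$ 为目标输出,$net$ 为神经元激活前的加权输入,则输出层和隐含层的误差效能分别为 $\delta_k=(y_k-o_k)\,f'(net_k)$,$\delta_j=f'(net_j)\sum_{k=1}^{S_2}w^{(2)}_{kj}\,\delta_k$。
输入层到隐含层的权值 $W_1$ 与偏置 $B_1$ 的变化($\alpha$ 为学习率):$\Delta w^{(1)}_{ji}=\alpha\,\delta_j\,x_i$,$\Delta b^{(1)}_j=\alpha\,\delta_j$。
隐含层到输出层的权值 $W_2$ 与偏置 $B_2$ 的变化:$\Delta w^{(2)}_{kj}=\alpha\,\delta_k\,a_j$,$\Delta b^{(2)}_k=\alpha\,\delta_k$。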
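这里公式中的函数 $f$ 称为激活函数,常用的激活函数有:线性函数 purelin $f(x)=x$、对数S型函数 logsig $f(x)=\frac{1}{1+e^{-x}}$ 和双曲正切S型函数 tansig $f(x)=\frac{2}{1+e^{-2x}}-1$。
我这里使用的是 logsig 函数。它的导数为 $f'(x)=f(x)\,(1-f(x))$,因此误差效能中的 $f'(net)$ 可以直接用该神经元的输出 $o$ 计算为 $o(1-o)$。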
具体的算法实现过程:
1. 初始化权矩阵W1、W2,阈值向量B1、B2;
2. 设置精度控制参数 ε 和学习率 α,令精度控制变量 d = ε+1;迭代计数 t = 0,T 为最大迭代次数;
3. while d ≥ ε do
4. d=0;
5. for 每个样本(X,Y) do
6. 输入X,计算隐含层输出A;
7. 隐含层输出A作为输出层的输入,计算输出层的输出O(即模型输出);
8. 计算累积误差:$d=d+\sum_{i}(y_i-o_i)^2$;
9. 计算输出层的误差效能,并据此反向计算隐含层的误差效能;
10. 根据输出层的误差效能修正W2、B2;
11. 根据隐含层的误差效能修正W1、B1,并令 t++。循环结束后观察 t 的最终值,判断是因达到精度要求而退出,还是因达到最大迭代次数 T 而退出:前者说明算法收敛;后者则可能需要调整算法(例如学习率或隐含层神经元个数)。
下面是具体实现:
主流程类 BP:
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Scanner;
/**
 * Entry point and training driver for a three-layer (input / hidden / output) BP neural network.
 *
 * <p>Inputs are read from {@code carb_x.txt} and targets from {@code carb_t.txt}: one
 * tab-separated vector per line, GBK-encoded, with the k-th non-blank line of both files forming
 * one sample. About 70% of the samples are chosen at random for training; once the mean training
 * error drops below {@code LIMIT}, the remaining samples are evaluated and printed.
 */
public class BP {
    // All samples loaded from the data files.
    private ArrayList<Input> inputList = new ArrayList<Input>();
    // weightOne: hide x inputLong (input -> hidden); weightTwo: outLong x hide (hidden -> output).
    private double[][] weightOne,weightTwo;
    // Biases of the hidden layer (baisOne) and of the output layer (baisTwo).
    private double[] baisOne,baisTwo;
    // Number of hidden-layer neurons.
    private final int hide = 10;
    // Input and output vector lengths, taken from the data files.
    private int inputLong,outLong;
    // Learning rate; read by Input when updating weights and biases.
    public static double alpha = 0.2;
    // Training stops once (1 / (2t)) * (sum of squared errors over t training samples) is below this.
    private final double LIMIT = 0.0001;
    // Maximum number of training epochs; exceeding it aborts training.
    private final int MAX = 1000000;

    public static void main(String[] args){
        new BP();
    }

    /**
     * Loads the data, initializes every weight and bias uniformly in [-1, 1), randomly assigns
     * each sample to the training set (probability 0.7) or the test set, then trains and reports.
     */
    public BP(){
        initMethod();
        weightOne = new double[hide][inputLong];
        weightTwo = new double[outLong][hide];
        baisOne = new double[hide];
        baisTwo = new double[outLong];
        randomize(weightOne, baisOne);
        randomize(weightTwo, baisTwo);
        for (Input input : inputList) {
            // (int)(random * 10) is uniform over 0..9, so 7 of the 10 outcomes mean "train".
            input.setRead((int) (Math.random() * 10) < 7);
        }
        doing();
    }

    /** Fills every bias and weight with a uniform random value in [-1, 1). */
    private static void randomize(double[][] weights, double[] biases) {
        for (int i = 0; i < weights.length; i++) {
            biases[i] = Math.random() * 2 - 1;
            for (int j = 0; j < weights[i].length; j++) {
                weights[i][j] = Math.random() * 2 - 1;
            }
        }
    }

    /**
     * Runs per-sample back-propagation over the training samples, one epoch per loop iteration,
     * until the epoch's mean error (1 / (2t)) * sum falls below LIMIT. Then every held-out
     * sample is evaluated and printed. Gives up after more than MAX epochs.
     */
    private void doing() {
        int epoch = 0;
        boolean converged = false;
        while (!converged && epoch <= MAX) {
            epoch++;
            double sum = 0;
            int t = 0;
            for (Input input : inputList) {
                if (!input.isRead()) {
                    continue;
                }
                input.firstStep(weightOne, baisOne);   // forward: input -> hidden
                input.secondStep(weightTwo, baisTwo);  // forward: hidden -> output, squared error
                input.threeStep(weightTwo);            // error terms of both layers
                input.fourStep(weightTwo, baisTwo);    // update hidden -> output weights/biases
                input.fiveStep(weightOne, baisOne);    // update input -> hidden weights/biases
                sum += input.getTotalTwo();
                t++;
            }
            if (t == 0) {
                // Without training samples the mean below is NaN and the loop would spin
                // uselessly through all MAX epochs.
                System.out.println("没有可用于训练的样本!");
                return;
            }
            sum = 1.0/(2*t)*sum;
            if (sum < LIMIT) {
                converged = true;
            }
        }
        if (!converged) {
            System.out.println("尴尬,超出次数了!");
            return;
        }
        for (Input input : inputList) {
            if (!input.isRead()) {
                input.test(weightOne, weightTwo, baisOne, baisTwo);
                input.showAnswer();
            }
        }
    }

    /**
     * Reads the input vectors from carb_x.txt and the target vectors from carb_t.txt (both
     * tab-separated, GBK). Blank lines are skipped. Samples without a matching target line are
     * dropped and surplus target lines are ignored, so every kept sample has both data and aim.
     * On a read or parse error all samples are discarded.
     */
    private void initMethod() {
        String testPath = "carb_x.txt";
        String resultPath = "carb_t.txt";
        File testFile = new File(testPath);
        File resultFile = new File(resultPath);
        if (!testFile.isFile() || !resultFile.isFile()) {
            System.out.println("找不到数据文件: " + testPath + " 或 " + resultPath);
            return;
        }
        BufferedReader readerOne = null;
        BufferedReader readerTwo = null;
        try {
            readerOne = new BufferedReader(
                    new InputStreamReader(new FileInputStream(testFile), "gbk"));
            readerTwo = new BufferedReader(
                    new InputStreamReader(new FileInputStream(resultFile), "gbk"));
            String lineOne;
            while ((lineOne = readerOne.readLine()) != null) {
                if (lineOne.trim().length() == 0) {
                    continue;
                }
                String[] strOne = lineOne.split("\t");
                Input input = new Input(hide);
                input.setData(change(strOne));
                inputList.add(input);
                inputLong = strOne.length;
            }
            int m = 0;
            String lineTwo;
            while (m < inputList.size() && (lineTwo = readerTwo.readLine()) != null) {
                if (lineTwo.trim().length() == 0) {
                    continue;
                }
                String[] strTwo = lineTwo.split("\t");
                inputList.get(m++).setAim(change(strTwo));
                outLong = strTwo.length;
            }
            if (m < inputList.size()) {
                // A sample without a target has no output-layer buffers and would fail in training.
                inputList.subList(m, inputList.size()).clear();
            }
        } catch (Exception e) {
            e.printStackTrace();
            // Partially loaded samples may lack targets; drop them all rather than fail later.
            inputList.clear();
        } finally {
            closeQuietly(readerOne);
            closeQuietly(readerTwo);
        }
    }

    /** Closes the reader if it was opened; a failure to close is not actionable here. */
    private static void closeQuietly(BufferedReader reader) {
        if (reader == null) {
            return;
        }
        try {
            reader.close();
        } catch (Exception ignored) {
            // Reading has already finished (or failed); there is nothing left to recover.
        }
    }

    /** Parses every token as a double; Double.parseDouble ignores surrounding whitespace. */
    private double[] change(String[] strOne) {
        double[] values = new double[strOne.length];
        for (int i = 0; i < strOne.length; i++) {
            values[i] = Double.parseDouble(strOne[i]);
        }
        return values;
    }
}
样本类 Input:
/**
 * A single sample for the three-layer BP network: its input vector, its target vector, and the
 * per-sample buffers used by the forward pass, back-propagation and testing.
 *
 * <p>Training call order for one sample: firstStep, secondStep, threeStep, fourStep, fiveStep.
 * Both setData and setAim must be called before training or testing, because setAim allocates
 * the output-layer buffers. Every layer uses the logistic sigmoid 1 / (1 + e^-x) as activation.
 */
public class Input {
    private double[] data;      // input vector x
    private double[] middle;    // hidden-layer outputs (after activation)
    private double[] borrow;    // output-layer outputs (after activation) during training
    private double[] tempOne;   // hidden-layer net inputs (before activation)
    private double[] tempTwo;   // output-layer net inputs (before activation)
    private double[] aim;       // target vector
    private double[] errorTwo;  // squared error (aim[i] - borrow[i])^2 per output
    private double totalTwo;    // sum of errorTwo for the latest training pass
    private double[] answer;    // outputs computed by test()
    private double[] one, two;  // error terms (deltas): one = hidden layer, two = output layer
    // true: this sample is in the training set; false: it is held out and evaluated by test()
    private boolean read = false;

    /** Creates a sample whose hidden-layer buffers hold {@code hide} neurons. */
    public Input(int hide){
        this.middle = new double[hide];
        this.tempOne = new double[hide];
        this.one = new double[hide];
    }

    /** Logistic sigmoid activation (logsig). */
    private static double sigmoid(double x) {
        return 1.0 / (1.0 + Math.exp(-x));
    }

    /** Zeroes every per-pass buffer and the accumulated error before a new training pass. */
    private void clean(){
        for (int i = 0; i < tempOne.length; i++) {
            tempOne[i] = 0;
            middle[i] = 0;
        }
        for (int i = 0; i < tempTwo.length; i++) {
            tempTwo[i] = 0;
            errorTwo[i] = 0;
            borrow[i] = 0;
        }
        totalTwo = 0;
        for (int i = 0; i < one.length; i++) {
            one[i] = 0;
        }
        for (int i = 0; i < two.length; i++) {
            two[i] = 0;
        }
    }

    /**
     * Forward pass from the input layer to the hidden layer; first resets all per-pass buffers.
     *
     * @param weightOne hidden x input weights (row i holds the weights into hidden neuron i)
     * @param baisOne   hidden-layer biases
     */
    public void firstStep(double[][] weightOne,double[] baisOne){
        clean();
        for (int i = 0; i < weightOne.length; i++) {
            double net = 0;
            for (int j = 0; j < weightOne[i].length; j++) {
                net += weightOne[i][j] * data[j];
            }
            net += baisOne[i];
            tempOne[i] = net;
            middle[i] = sigmoid(net);
        }
    }

    /**
     * Forward pass from the hidden layer to the output layer. Records the squared error of each
     * output in errorTwo and adds it to totalTwo; must follow firstStep, which resets totalTwo.
     *
     * @param weightTwo output x hidden weights (row i holds the weights into output neuron i)
     * @param baisTwo   output-layer biases
     */
    public void secondStep(double[][] weightTwo,double[] baisTwo){
        for (int i = 0; i < weightTwo.length; i++) {
            double net = 0;
            for (int j = 0; j < weightTwo[i].length; j++) {
                net += weightTwo[i][j] * middle[j];
            }
            net += baisTwo[i];
            tempTwo[i] = net;
            borrow[i] = sigmoid(net);
            double diff = aim[i] - borrow[i];
            errorTwo[i] = diff * diff;
            totalTwo += errorTwo[i];
        }
    }

    /**
     * Computes the error terms (deltas).
     * Output layer: two[k] = (aim[k] - o[k]) * o[k] * (1 - o[k]).
     * Hidden layer: one[j] = (sum over k of weightTwo[k][j] * two[k]) * a[j] * (1 - a[j]).
     * Must run before fourStep changes weightTwo, so the hidden deltas use the forward-pass weights.
     */
    public void threeStep(double[][] weightTwo){
        for (int i = 0; i < two.length; i++) {
            two[i] = (aim[i] - borrow[i]) * borrow[i] * (1 - borrow[i]);
        }
        for (int j = 0; j < weightTwo[0].length; j++) {      // column = hidden neuron
            double backSum = 0;
            for (int k = 0; k < weightTwo.length; k++) {     // row = output neuron
                backSum += weightTwo[k][j] * two[k];
            }
            one[j] = backSum * middle[j] * (1 - middle[j]);
        }
    }

    /**
     * Gradient step for the hidden -> output layer, in place:
     * weightTwo[i][j] += alpha * two[i] * middle[j], baisTwo[i] += alpha * two[i].
     */
    public void fourStep(double[][] weightTwo,double[] baisTwo){
        for (int i = 0; i < weightTwo.length; i++) {
            for (int j = 0; j < weightTwo[i].length; j++) {
                weightTwo[i][j] = weightTwo[i][j] + (BP.alpha * two[i] * middle[j]);
            }
            baisTwo[i] = baisTwo[i] + BP.alpha * two[i];
        }
    }

    /**
     * Gradient step for the input -> hidden layer, in place:
     * weightOne[i][j] += alpha * one[i] * data[j], baisOne[i] += alpha * one[i].
     */
    public void fiveStep(double[][] weightOne,double[] baisOne){
        for (int i = 0; i < weightOne.length; i++) {
            for (int j = 0; j < weightOne[i].length; j++) {
                weightOne[i][j] = weightOne[i][j] + (BP.alpha * one[i] * data[j]);
            }
            baisOne[i] = baisOne[i] + BP.alpha * one[i];
        }
    }

    /**
     * Runs a forward pass with the given weights and stores the network output in answer.
     * Net inputs are recomputed from zero on every call, so the result depends neither on
     * earlier calls nor on earlier training passes over this sample.
     */
    public void test(double[][] weight1,double[][] weight2,double[] bais1,double[] bais2){
        for (int i = 0; i < weight1.length; i++) {
            double net = 0;
            for (int j = 0; j < weight1[i].length; j++) {
                net += weight1[i][j] * data[j];
            }
            net += bais1[i];
            tempOne[i] = net;
            middle[i] = sigmoid(net);
        }
        for (int i = 0; i < weight2.length; i++) {
            double net = 0;
            for (int j = 0; j < weight2[i].length; j++) {
                net += weight2[i][j] * middle[j];
            }
            net += bais2[i];
            tempTwo[i] = net;
            answer[i] = sigmoid(net);
        }
    }

    /** Prints the input, the expected output and the output computed by test(). */
    public void showAnswer(){
        System.out.println("输入:");
        for(int i = 0;i < data.length;i ++){
            System.out.print(data[i]+" ");
        }
        System.out.println("\n应输出:");
        for(int i = 0;i < aim.length;i ++){
            System.out.print(aim[i]+" ");
        }
        System.out.println("\n实际输出:");
        for(int i = 0;i < answer.length;i ++){
            System.out.print(answer[i]+" ");
        }
        System.out.println("\n----------------------------------------------------------------------");
    }

    public double[] getAnswer() {
        return answer;
    }

    /** Debug print: input values, then "XXX", then target values, on one line. */
    public void show(){
        for(int i = 0;i <data.length;i ++){
            System.out.print(data[i]+" ");
        }
        System.out.print("XXX");
        for(int i = 0;i < aim.length;i ++){
            System.out.print(aim[i]+" ");
        }
        System.out.println("");
    }

    //------------------ getters and setters ------------------
    public double[] getData() {
        return data;
    }
    public void setData(double[] data) {
        this.data = data;
    }
    public double[] getMiddle() {
        return middle;
    }
    public void setMiddle(double[] middle) {
        this.middle = middle;
    }
    public double[] getTempOne() {
        return tempOne;
    }
    public void setTempOne(double[] tempOne) {
        this.tempOne = tempOne;
    }
    public double[] getTempTwo() {
        return tempTwo;
    }
    public void setTempTwo(double[] tempTwo) {
        this.tempTwo = tempTwo;
    }
    public double[] getAim() {
        return aim;
    }
    /** Sets the target vector and allocates every output-layer buffer with its length. */
    public void setAim(double[] aim) {
        this.aim = aim;
        this.borrow = new double[aim.length];
        this.answer = new double[aim.length];
        this.errorTwo = new double[aim.length];
        this.tempTwo = new double[aim.length];
        this.two = new double[aim.length];
    }
    public double[] getErrorTwo() {
        return errorTwo;
    }
    public void setErrorTwo(double[] errorTwo) {
        this.errorTwo = errorTwo;
    }
    public double getTotalTwo() {
        return totalTwo;
    }
    public void setTotalTwo(double totalTwo) {
        this.totalTwo = totalTwo;
    }
    public boolean isRead() {
        return read;
    }
    public void setRead(boolean read) {
        this.read = read;
    }
}
本文参考了百度百科中关于BP神经网络的相关词条。