人工智能之BP神经网络

什么是BP神经网络,BP网络是一种按照误差逆向传播算法训练的多层前馈神经网络。基本的BP神经网路包括信号的前向传播和误差的反向传播两个过程。这里我们具体研究的是简单的三层BP网络。这三层分别是,输入层、隐含层和输出层。如下图,输入层有输入X共n个输入,隐含层的神经元个数由自己来定义,为S1个,输出层即是目标输出共S2个。


如上图所示,我们可以得知,输入层到隐含层的加权矩阵为:S1*n,偏置矩阵为S1*1;隐含层到输出层的加权矩阵为S2*S1,偏置矩阵为S2*1。

首先,我们先来定义误差函数

这是一个样本的误差,一组样本的误差就是所有误差之和去均值即可

相比与感知器神经元,BP网络在权值变化时引入了一个误差效能δk


输入层的权值变化:


隐含层的权值变化:




这里,公式中有一个函数f,我们称之为激活函数,激活函数可以选用:


我这里使用的是logsig函数。

具体的算法实现过程:

1.  初始化权矩阵W1、W2,阈值向量B1、B2;

2.  初置精度控制参数e,学习率a,精度控制变量d= e+1;(t=0,T迭代次数)

3.  While d³ e do

4.  d=0;

5.  for 每个样本(X,Y) do

6.  输入X,计算隐含层输出A;

7.  隐含层输出A作为输出层的输入,计算输出层的输出O(即模型输出);

8.  计算累积误差:d=d+(yi-oi)

9.  根据输出层的误差效能计算隐含层的误差;

10.          根据输出层的误差效能修正W2、B2;

11.          根据隐含层的误差效能修正W1、B1(t++)观察t最后的值是根据t还是精度退出的,如果是根据精度退出则收敛,如果根据T则可能修改算法

下面是具体实现:

主流程类BP():

import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Scanner;

public class BP {
	private ArrayList<Input> inputList = new ArrayList<Input>();
	private double[][] weightOne,weightTwo;
	private double[] baisOne,baisTwo;
	private int hide = 10;
	private int inputLong,outLong;
	public static double alpha = 0.2;
	private double LIMIT = 0.0001;
	private int MAX = 1000000;
	public static void main(String[] args){
		new BP();
	}
	public BP(){
		initMethod();	//读取数据
		weightOne = new double[hide][inputLong];	//输入层到隐含层
		weightTwo = new double[outLong][hide];	//隐含层到输出层
		baisOne = new double[hide];		//输入层到隐含层偏置
		baisTwo = new double[outLong];	//隐含层到输出层偏置
		for(int i = 0;i < weightOne.length;i ++){	//加权阵初始化
			baisOne[i] = Math.random()*2-1;
			for(int j = 0;j < weightOne[i].length;j ++){
				weightOne[i][j] = Math.random()*2-1;
			}
		}
		for(int i = 0;i < weightTwo.length;i ++){	//加权阵初始化
			baisTwo[i] = Math.random()*2-1;
			for(int j = 0;j < weightTwo[i].length;j ++){
				weightTwo[i][j] = Math.random()*2-1;
			}
		}
		for(int i = 0;i < inputList.size();i ++){
			//inputList.get(i).show();
			int temp = 0;
			temp = (int)(Math.random()*10);
			if(temp < 7){
				inputList.get(i).setRead(true);
			}else
				inputList.get(i).setRead(false);
		}
		doing();
	}
	private void doing() {
		// TODO Auto-generated method stub
		double sum = 0;int a = 0,t = 0;
		boolean test = true;
		int temp = 0;
		while(true){
			temp++;
			test = true;
			t=0;
			sum = 0;
			for(int i = 0;i < inputList.size();i ++){
				if(inputList.get(i).isRead()){
					inputList.get(i).firstStep(weightOne, baisOne);
					inputList.get(i).secondStep(weightTwo, baisTwo);
					inputList.get(i).threeStep(weightTwo);
					inputList.get(i).fourStep(weightTwo, baisTwo);
					inputList.get(i).fiveStep(weightOne, baisOne);
					sum += inputList.get(i).getTotalTwo();
					t++;
					/*System.out.print(sum);
					if(sum > LIMIT){
						test = false;
						System.out.print(" *");
					}
					System.out.println("");*/
				}
			}
			sum = 1.0/(2*t)*sum;
			if(sum < LIMIT){
				a = 1;
				for(int m = 0;m < inputList.size();m ++){
					//System.out.println(inputList.get(m).isRead());
					if(!inputList.get(m).isRead()){
						inputList.get(m).test(weightOne, weightTwo, baisOne, baisTwo);
						inputList.get(m).showAnswer();
					}
				}
				/*for(int n = 0;n <weightOne.length;n ++){
					for(int l = 0;l < weightOne[n].length;l ++)
					System.out.print(weightOne[n][l]+" ");
					System.out.println("");
				}
				System.out.println("+++++++++++++++++++");
				for(int n = 0;n <weightTwo.length;n ++){
					for(int l = 0;l < weightTwo[n].length;l ++)
					System.out.print(weightTwo[n][l]+" ");
					System.out.println("");
				}*/
				break;
			}
			if(temp > MAX) {break;}
			if(a == 1){
				break;
			}
		}
		if(temp > MAX)
		System.out.println("尴尬,超出次数了!");
	}
	private void initMethod() {
		// TODO Auto-generated method stub
		String testPath = "carb_x.txt";
		String resultPath = "carb_t.txt";
		try{
			File testFile = new File(testPath);
			File resultFile = new File(resultPath);
			if(testFile.isFile()&&testFile.exists()&&resultFile.isFile()&&resultFile.exists()){
				InputStreamReader readOne = new InputStreamReader(new FileInputStream(testFile),"gbk");
				InputStreamReader readTwo = new InputStreamReader(new FileInputStream(resultFile),"gbk");
				BufferedReader readerOne = new BufferedReader(readOne);
				BufferedReader readerTwo = new BufferedReader(readTwo);
				String lineOne,lineTwo;
				while((lineOne = readerOne.readLine())!=null){
					String[] strOne;
					strOne = lineOne.split("\t");
					Input input = new Input(hide);
					input.setData(change(strOne));
					inputList.add(input);
					inputLong = strOne.length;
				}
				int m = 0;
				while((lineTwo = readerTwo.readLine())!=null){
					String[] strTwo;
					strTwo = lineTwo.split("\t");
					inputList.get(m++).setAim(change(strTwo));
					outLong = strTwo.length;
				}
			
			}
		}catch(Exception e){
			e.printStackTrace();
		}
	}
	private double[] change(String[] strOne) {
		// TODO Auto-generated method stub
		double[] str;
		str = new double[strOne.length];
		for(int i = 0;i < strOne.length;i ++){
			str[i] = Double.parseDouble(strOne[i]);
		}
		return str;
	}
}

输入类型Input():

public class Input {
	private double[] data;//输入
	private double[] middle;//输入层到隐含层
	private double[] borrow;//隐含层到输出层
	private double[] tempOne;//未激活的隐含层结果
	private double[] tempTwo;//未激活的输出层结果
	private double[] aim;//目标
	private double[] errorTwo;
	private double totalTwo;
	private double[] answer;
	private double[] one,two;	//误差效能
	private boolean read = false;//是否用于测试结果,false表示还未用来学习,true表示
								//这个输入用来学习,不能用于测试了
	//private double[][] w1,w2;
	private void clean(){
		for(int i = 0;i < tempOne.length;i ++){
			tempOne[i] = 0;
			middle[i] = 0;
		}
		for(int i = 0;i < tempTwo.length;i ++){
			tempTwo[i] = 0;
			errorTwo[i] = 0;
			borrow[i] = 0;
		}
		totalTwo = 0;
		for(int i = 0;i < one.length;i ++){
			one[i] = 0;
		}
		for(int i = 0;i < errorTwo.length;i ++){
			two[i] = 0;
		}
		
	}
	
	public void firstStep(double[][] weightOne,double[] baisOne){		//输入层到隐含层
		clean();
		/*w1 = new double[weightOne.length][weightOne[0].length];
		for(int i = 0;i < w1.length;i ++){
			for(int j = 0;j < w1[i].length;j ++){
				w1[i][j] = weightOne[i][j];
			}
		}*/
		for(int i = 0;i < weightOne.length;i ++){
			tempOne[i] = 0;
			for(int j = 0;j < weightOne[i].length;j ++){
				tempOne[i] = tempOne[i] + weightOne[i][j]*data[j];
			}
			tempOne[i]=tempOne[i]+baisOne[i];
			middle[i] = 1.0/(1.0+Math.exp(-tempOne[i]));
		}
	}
	
	public void secondStep(double[][] weightTwo,double[] baisTwo){	//隐含层到输出层
		/*w2 = new double[weightTwo.length][weightTwo[0].length];
		for(int i = 0;i < w2.length;i ++){
			for(int j = 0;j < w2[i].length;j ++){
				w2[i][j] = weightTwo[i][j];
			}
		}*/
		for(int i = 0;i < weightTwo.length;i ++){
			tempTwo[i] = 0;borrow[i]=0;
			for(int j = 0;j < weightTwo[i].length;j ++){
				tempTwo[i] =tempTwo[i] + weightTwo[i][j]*middle[j];
			}
			tempTwo[i]=tempTwo[i]+baisTwo[i];
			borrow[i] = 1.0/(1.0+Math.exp(-tempTwo[i]));
			errorTwo[i] = (aim[i] - borrow[i])*(aim[i] - borrow[i]);
			totalTwo = totalTwo+errorTwo[i];
		}
		//totalTwo = ((double)1/2*totalTwo);
	}

	public void threeStep(double[][] weightTwo){	//计算输出层误差效能
		for(int i = 0;i < errorTwo.length;i ++){
			two[i] = (aim[i] - borrow[i])*borrow[i]*(1-borrow[i]);
		}
		for(int j = 0;j < weightTwo[0].length;j ++){//列
			for(int k = 0;k < weightTwo.length;k ++){//行
				one[j]+= weightTwo[k][j]*two[k];
			}
			one[j] = one[j]*middle[j]*(1-middle[j]);
		}
	}
	
	public void fourStep(double[][] weightTwo,double[] baisTwo){		//重新更新偏置矩阵
		for(int i = 0;i < weightTwo.length;i ++){
			for(int j = 0;j < weightTwo[i].length;j ++){
				weightTwo[i][j] = weightTwo[i][j] + (BP.alpha*two[i]*middle[j]);
			}
			baisTwo[i] = baisTwo[i] + BP.alpha*two[i];
		}
	}
	public void fiveStep(double[][] weightOne,double[] baisOne){		//重新更新偏置矩阵
		for(int i = 0;i < weightOne.length;i ++){
			for(int j = 0;j < weightOne[i].length;j ++){
				weightOne[i][j] = weightOne[i][j] + (BP.alpha*one[i]*data[j]);
			}
			baisOne[i] = baisOne[i] + BP.alpha*one[i];
		}
	}
	
	/*public double activate(double total){		//激活函数
		return (1.0/(1.0+Math.exp(-total)));
	}
	
	public double activateDer(double total){	//激活函数的导数
		return total*(1-total);
	}*/
	
	public void test(double[][] weight1,double[][] weight2,double[] bais1,double[] bais2){
		for(int i = 0;i < weight1.length;i ++){
			for(int j = 0;j < weight1[i].length;j ++){
				tempOne[i] = tempOne[i] + weight1[i][j]*data[j];
			}
			tempOne[i]+=bais1[i];
			middle[i] = 1.0/(1.0+Math.exp(-tempOne[i]));
		}
		for(int i = 0;i < weight2.length;i ++){
			for(int j = 0;j < weight2[i].length;j ++){
				tempTwo[i] = tempTwo[i] + weight2[i][j]*middle[j];
			}
			tempTwo[i] = tempTwo[i]+bais2[i];
			answer[i] = 1.0/(1.0+Math.exp(-tempTwo[i]));
		}
	}
	public void showAnswer(){
		System.out.println("输入:");
		for(int i = 0;i < data.length;i ++){
			System.out.print(data[i]+"  ");
		}
		System.out.println("\n应输出:");
		for(int i = 0;i < aim.length;i ++){
			System.out.print(aim[i]+"  ");
		}
		System.out.println("\n实际输出:");
		for(int i = 0;i < answer.length;i ++){
			System.out.print(answer[i]+"  ");
		}
		System.out.println("\n----------------------------------------------------------------------");
	}
	
	public double[] getAnswer() {
		return answer;
	}
	
	public Input(int hide){
		this.middle = new double[hide];
		this.tempOne = new double[hide];
		this.one = new double[hide];
	}
	public void show(){
		for(int i = 0;i <data.length;i ++){
			System.out.print(data[i]+"  ");
		}
		System.out.print("XXX");
		for(int i = 0;i < aim.length;i ++){
			System.out.print(aim[i]+"  ");
		}
		System.out.println("");
	}
	//------------------set和get方法----------------------------------------------------
	public double[] getData() {
		return data;
	}
	public void setData(double[] data) {
		this.data = data;
	}
	public double[] getMiddle() {
		return middle;
	}
	public void setMiddle(double[] middle) {
		this.middle = middle;
	}
	
	public double[] getTempOne() {
		return tempOne;
	}

	public void setTempOne(double[] tempOne) {
		this.tempOne = tempOne;
	}

	public double[] getTempTwo() {
		return tempTwo;
	}

	public void setTempTwo(double[] tempTwo) {
		this.tempTwo = tempTwo;
	}

	public double[] getAim() {
		return aim;
	}
	public void setAim(double[] aim) {
		this.aim = aim;
		this.borrow = new double[aim.length];
		this.answer = new double[aim.length];
		this.errorTwo = new double[aim.length];
		this.tempTwo = new double[aim.length];
		this.two = new double[aim.length];
	}
	public double[] getErrorTwo() {
		return errorTwo;
	}
	public void setErrorTwo(double[] errorTwo) {
		this.errorTwo = errorTwo;
	}
	public double getTotalTwo() {
		return totalTwo;
	}
	public void setTotalTwo(double totalTwo) {
		this.totalTwo = totalTwo;
	}
	public boolean isRead() {
		return read;
	}
	public void setRead(boolean read) {
		this.read = read;
	}
}

这里参考了一下百度百科: 点击打开链接

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值