数据挖掘-BP算法实现

import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;


public class BP {
	List<String> data_var=new ArrayList<String>();//存储输入的变量数据
	List<String> data_tag=new ArrayList<String>();//存储输入的类别数据
	String[] kinds;  //用户存储类别的变量,kinds[0]表示第0类对应的类别名称
    public float[][][] w=new float[50][50][50];//存储层与层连接边的权重,我w[i][j][k]表示第i个层隐藏层的w[][]权重值,具体:第i-1层的第j个节点与第i层的第k个节点连接边的权重
    public float[][] beta=new float[50][50];//存储节点的偏移量,beta[i][j]表示第i个隐藏层中,第j个节点的偏倚量.
	
	public BP() throws IOException{//函数作用:数据初始化,从文件读入输入的:变量数据、类别数据,对类别数据类别数量划分一条数据的列数.
		for(int i=0;i<50;i++){
			for(int j=0;j<50;j++){
				for(int k=0;k<50;k++){
					w[i][j][k]=(float)Math.random();
				}
				    beta[i][j]=(float)Math.random();
			}
		}
		BufferedReader br=new BufferedReader(new FileReader("F:/数据挖掘--算法实现/BP算法/input.txt"));
        String line="";
        while((line=br.readLine())!=null){
        	data_tag.add(line.split(" ",2)[0]);
        	data_var.add(line.split(" ",2)[1]);
     	}
        HashSet<String> set=new HashSet<String>();
        for(int i=0;i<data_tag.size();i++){
        	set.add(data_tag.get(i));
        }
        kinds=new String[set.size()];
        int i=0;
        Iterator<String> Iter=set.iterator();
        while(Iter.hasNext()){
        	this.kinds[i]=Iter.next();
        	i++;
        }
        for(int j=0;j<data_tag.size();j++){
        	String tmp="";
        	for(i=0;i<kinds.length;i++){
        		if(kinds[i].equals(data_tag.get(j))){
        			tmp=tmp+" "+"1";
        		}else{
        			tmp=tmp+" "+"0";
        		}
        	}
        	data_tag.set(j,tmp.trim());
        }
	}
	
	public String Input_layer(String input_str){//函数作用:输入层的输入输出值
		String Output_str=input_str;
		return Output_str;
	}
	
	public String Hidden_layer(String input_str,int node_num,float[][] w,float beta[]){//函数作用:隐藏层的输入输出值
		String Output_str="";
		String[] input_tmp=input_str.split(" ");
		float[] input=new float[input_tmp.length];
		float[] I=new float[node_num];
		float[] O=new float[node_num];
		for(int i=0;i<input_tmp.length;i++){
			input[i]=Float.parseFloat(input_tmp[i]);
		}
		for(int j=0;j<I.length;j++){
			I[j]=beta[j];
			for(int i=0;i<input.length;i++){
				I[j]=I[j]+input[i]*w[i][j];
			}
			O[j]=(float) (1/(Math.exp(-I[j])+1));
			Output_str=Output_str+" "+String.valueOf(O[j]);
		}
		return Output_str.trim();
	}
	
	public String Output_err(String output_str,String true_str){//函数作用:计算输出层的误差
		String err="";
		float tmp=0f;
		String[] output_tmp=output_str.split(" ");
		float[] output=new float[output_tmp.length];
		String[] true_tmp=true_str.split(" ");
		float[] true_=new float[true_tmp.length];
		for(int i=0;i<output_tmp.length;i++){
			output[i]=Float.parseFloat(output_tmp[i]);
			true_[i]=Float.parseFloat(true_tmp[i]);
			tmp=output[i]*(1-output[i])*(true_[i]-output[i]);
			err=err+" "+String.valueOf(tmp);
		}
		return err.trim();
	}
	
	public String Hidden_err(String output_str,String err_next_layer,float w[][]){//函数作用:计算隐藏层的误差
		String err="";
		float tmp=0f;
		String[] err_next_layer_tmp=err_next_layer.split(" ");
		float[] err_next_layer_=new float[err_next_layer_tmp.length];
		for(int k=0;k<err_next_layer_tmp.length;k++){
			err_next_layer_[k]=Float.parseFloat(err_next_layer_tmp[k]);
		}
		String[] output_tmp=output_str.split(" ");
		float[] output=new float[output_tmp.length];
		for(int j=0;j<output_tmp.length;j++){
			output[j]=Float.parseFloat(output_tmp[j]);
			tmp=0f;
			for(int k=0;k<err_next_layer_.length;k++){
				tmp=tmp+err_next_layer_[k]*w[j][k];
			}
			tmp=tmp*output[j]*(1-output[j]);
			err=err+" "+String.valueOf(tmp);
		}
		return err.trim();
	}
		
	private float Predict_rate(List<String> predict_tag, List<String> data_tag){//函数作用:准确率计算函数,计算预测模型的准确度
		int correct_count=0;
		for(int i=0;i<predict_tag.size();i++){
			String[] tmp_0=predict_tag.get(i).split(" ");
			String[] tmp_1=data_tag.get(i).split(" ");
			int max_index=0;
			float max_value=Float.parseFloat(tmp_0[0]);
			for(int index=1;index<tmp_0.length;index++){
				if(Float.parseFloat(tmp_0[index])>max_value)
				   {max_value=Float.parseFloat(tmp_0[index]);max_index=index;}
			}
			if(tmp_1[max_index].equals("1")){correct_count++;}
			System.out.println("第"+i+"行数据预测为:"+kinds[max_index]+"类");
		}
		return ((float)correct_count)/predict_tag.size();
	}

	private void updata_beta(float[] beta, String this_eorr, float l) {//函数作用:更新一层的偏倚量beta
		String[] this_eorr_tmp=this_eorr.split(" ");
		float[] this_eorr_=new float[this_eorr_tmp.length];
		for(int j=0;j<this_eorr_tmp.length;j++) {
			this_eorr_[j]=Float.parseFloat(this_eorr_tmp[j]);
			beta[j]=beta[j]+l*this_eorr_[j];
		}
	}
		
	private float updata_w(float[][] w, String last_output, String this_eorr,float l) {//函数作用:更新一层的边的权重w
		float max=0;
		String[] last_output_tmp=last_output.split(" ");
		float[] last_output_=new float[last_output_tmp.length];
		for(int i=0;i<last_output_tmp.length;i++) last_output_[i]=Float.parseFloat(last_output_tmp[i]);
		String[] this_eorr_tmp=this_eorr.split(" ");
		float[] this_eorr_=new float[this_eorr_tmp.length];
		for(int j=0;j<this_eorr_tmp.length;j++) this_eorr_[j]=Float.parseFloat(this_eorr_tmp[j]);
		for(int i=0;i<last_output_.length;i++){
			for(int j=0;j<this_eorr_.length;j++){
				w[i][j]=w[i][j]+l*this_eorr_[j]*last_output_[i];
				if(Math.abs(l*this_eorr_[j]*last_output_[i])>max) max=Math.abs(l*this_eorr_[j]*last_output_[i]);
			}
		}
		return max;
	}

	public List<String> BP_two_layer(int node){//函数作用:两个隐藏层的BP算法模型,node是第一个隐藏层的节点数量,第二隐藏层的节点数量等于类别的数量
		float n=1;
		String input;
		String output_hidden;
		String output;
		String output_eorr;
		String hidden_eorr;
		float eorr=0;
		while(1==1){
			eorr=0;
			for(int i=0;i<data_var.size();i++){
				input=Input_layer(data_var.get(i));
				output_hidden=Hidden_layer(input,node,w[0],beta[0]);
				output=Hidden_layer(output_hidden,kinds.length,w[1],beta[1]);
				output_eorr=Output_err(output,data_tag.get(i));
				hidden_eorr=Hidden_err(output_hidden,output_eorr,w[1]);
				eorr=updata_w(w[0],input,hidden_eorr,1/n);
				eorr=Math.max(eorr,updata_w(w[1],output_hidden,output_eorr,1/n));
				updata_beta(beta[0],hidden_eorr,1/n);
				updata_beta(beta[1],output_eorr,1/n);
			}
			n=n+1;		
	      	List<String> predict_tag=new ArrayList<String>();
	     	for(int i=0;i<data_var.size();i++){
			   input=Input_layer(data_var.get(i));
			   output_hidden=Hidden_layer(input,node,w[0],beta[0]);
			   output=Hidden_layer(output_hidden,kinds.length,w[1],beta[1]);
			   predict_tag.add(i,output);
		    }
	     	if(eorr<0.00001 | n>20000)
	     	  { System.out.println("准确率:"+Predict_rate(predict_tag,data_tag)+"  迭代次数:"+n+"  所有w[i][j]更新增量小于:"+eorr);
	     		return predict_tag;}
	     }
	}
	
	public static void main(String[] args) throws IOException {
		BP a=new BP();
		a.BP_two_layer(5);
	}

}


训练数据:

1 1.5 1.2 0.3
1 2.5 0.3 0.0
1 0.8 0.2 0.3
1 1.1 0.3 0.0
1 1.1 0.2 1.0
1 1.5 1.2 0.2
1 2.5 0.3 0.1
1 0.8 0.2 0.2
1 1.1 0.1 0.0
1 1.1 0.1 1.0
2 0.3 2.2 0.3
2 0.1 1.2 1.3
2 0.5 1.3 0.3
2 1.1 1.3 1.0
2 1.5 1.6 0.3
2 1.1 1.4 1.0
3 1.9 3.9 1.0
3 1.1 2.4 1.0
3 2.1 2.9 1.8
4 2.5 0.2 1.3
4 3.1 1.3 1.0
4 3.5 1.2 1.3
4 2.5 1.2 0.3
4 2.5 1.5 0.6
4 3.1 1.7 1.4
4 4.5 1.8 1.3
4 3.5 1.1 0.4
4 4.1 0.8 1.2
5 1.3 1.3 1.0
5 2.2 1.3 1.0
5 1.4 1.2 1.3
5 1.6 1.2 1.9
5 1.7 1.3 1.0
5 2.1 1.5 1.0
2 0.2 2.2 0.3
2 0.2 1.2 1.3
2 0.4 1.3 0.3
2 1.4 1.3 1.0
2 1.3 1.6 0.3
2 1.2 1.4 1.0
3 1.8 3.9 1.0
3 1.3 2.4 1.0
3 2.1 2.3 1.8
4 2.5 0.5 1.3
4 3.1 1.7 1.0
4 3.5 1.3 1.3
4 2.8 1.2 0.3
4 2.8 1.5 0.6
4 3.2 1.7 1.4
4 4.6 1.8 1.3
4 3.6 1.1 0.4
4 4.2 0.6 1.2
5 1.4 1.3 1.0
5 2.2 1.6 1.0
5 1.4 1.1 1.2
5 1.6 1.5 1.8
5 1.8 1.2 1.0
5 2.2 1.6 1.1
2 1.5 1.5 0.3
2 1.1 1.5 1.0
3 1.9 3.3 1.0
3 1.1 2.3 1.0
3 2.1 2.3 1.8
4 2.5 0.3 1.3
4 3.1 1.4 1.0
4 3.5 1.4 1.3
4 2.5 1.4 0.3
4 2.5 1.4 0.6
4 3.1 1.8 1.4
4 4.5 1.5 1.3
4 3.5 1.5 0.4
4 4.1 0.5 1.2
5 1.3 1.5 1.0
5 2.2 1.5 1.0
5 1.4 1.5 1.3
5 1.6 1.5 1.9
5 1.7 1.5 1.0
5 2.1 1.7 1.0
2 0.2 2.7 0.3
2 0.2 1.7 1.3
2 0.4 1.1 0.3
2 1.4 1.1 1.0
2 1.3 1.1 0.3
2 1.2 1.3 1.0
3 1.8 3.3 1.0
3 1.3 2.5 1.0
3 2.1 2.5 1.8
4 2.5 0.6 1.3
4 3.1 1.6 1.0
4 3.5 1.7 1.3
4 2.8 1.7 0.3
4 2.8 1.8 0.6
4 3.2 1.9 1.4
4 4.6 1.9 1.3
4 3.6 1.2 0.4
4 4.2 0.7 1.2
5 1.4 1.2 1.0
5 2.2 1.7 1.0
5 1.4 1.2 1.2
5 1.6 1.6 1.8

输出:

准确率:0.93  迭代次数:13201.0  所有w[i][j]更新增量小于:9.999344E-6

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值