数据挖掘——CART 算法的实现

import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;




/**
 * A minimal CART (Classification And Regression Tree) classifier over categorical attributes.
 *
 * <p>Every record is one space-separated line. The last column is the target (class) variable
 * and every other column is a categorical attribute. Each internal node is a binary split of one
 * attribute's values into two groups, chosen by maximal Gini gain. Each leaf is printed to
 * standard output as a rule with its record count, majority class and accuracy.
 */
public class Cart {
    /** Minimum Gini gain a split must exceed; at or below this the node becomes a leaf. */
    private static final double MIN_GAIN = 0.01;

    /** Space-separated header line: attribute names, with the target variable name last. */
    String Var = "";

    /**
     * Computes the Gini gain of splitting {@code Target} into the records whose attribute value
     * appears in {@code Split} versus all remaining records.
     *
     * @param Target records of the form {@code "attributeValue category"}, e.g. {@code "a1 c1"}
     * @param Split  space-separated attribute values forming one side of the split,
     *               e.g. {@code "a1 a2 a3"}
     * @return {@code Gini(Target)} minus the size-weighted Gini index of the two sides
     */
    public float Gini_compute(List<String> Target, String Split) {
        Set<String> splitValues = new HashSet<String>();
        for (String value : Split.split(" ")) {
            splitValues.add(value);
        }
        List<String> Target1 = new ArrayList<String>();
        List<String> Target2 = new ArrayList<String>();
        for (String record : Target) {
            if (splitValues.contains(record.split(" ")[0])) {
                Target1.add(record);
            } else {
                Target2.add(record);
            }
        }
        int total = Target1.size() + Target2.size();
        float weighted = Gini_index(Target1) * ((float) Target1.size()) / total;
        weighted += Gini_index(Target2) * ((float) Target2.size()) / total;
        return Gini_index(Target) - weighted;
    }

    /**
     * Computes the Gini index {@code 1 - sum(p_k^2)} of the categories in {@code Target}.
     *
     * @param Target records whose second space-separated field is the category
     * @return the Gini index; for an empty list no category contributes, so this returns 1
     *         (callers weight that value by a size of 0)
     */
    public float Gini_index(List<String> Target) {
        int n = Target.size();
        // Count occurrences of each category (second field of each record).
        Map<String, Integer> counts = new HashMap<String, Integer>();
        for (String record : Target) {
            String category = record.split(" ")[1];
            Integer seen = counts.get(category);
            counts.put(category, seen == null ? 1 : seen + 1);
        }
        float sum = 0;
        for (int count : counts.values()) {
            sum += (((float) count) / n) * (((float) count) / n);
        }
        return 1 - sum;
    }

    /**
     * Finds the best binary partition of the values in column {@code i} of {@code DataSet}.
     *
     * <p>All subsets of the column's distinct values are evaluated, which costs
     * {@code 2^k} Gini computations for {@code k} distinct values.
     *
     * @param DataSet space-separated records whose last field is the target category
     * @param i       index of the attribute column to split
     * @return a two-element list: the space-separated value subset with the highest Gini gain
     *         (the first one found on ties), and that gain formatted with {@link String#valueOf(float)}
     */
    public List<String> Gini_select(List<String> DataSet, int i) {
        List<String> DataSet_i = new ArrayList<String>();
        Set<String> DataSet_i_set = new HashSet<String>();
        for (String row : DataSet) {
            String[] fields = row.split(" ");
            DataSet_i.add(fields[i] + " " + fields[fields.length - 1]);
            DataSet_i_set.add(fields[i]);
        }
        StringBuilder values = new StringBuilder();
        for (String value : DataSet_i_set) {
            values.append(' ').append(value);
        }
        ArrayList<String> list = new ArrayList<String>();
        doGetSubSequences(values.toString().trim(), "", list);

        String max_set = list.get(0);
        float max = Gini_compute(DataSet_i, max_set);
        for (int j = 1; j < list.size(); j++) {
            // Evaluate each candidate once; strict '>' keeps the first best subset on ties.
            float gain = Gini_compute(DataSet_i, list.get(j));
            if (gain > max) {
                max = gain;
                max_set = list.get(j);
            }
        }
        List<String> return_list = new ArrayList<String>();
        return_list.add(max_set);
        return_list.add(String.valueOf(max));
        return return_list;
    }

    /**
     * Appends every subset of the space-separated values in {@code word} to {@code list}, each as
     * a trimmed space-separated string. The empty subset is produced first.
     *
     * @param word remaining space-separated values to choose from
     * @param s    values chosen so far
     * @param list output accumulator
     */
    private static void doGetSubSequences(String word, String s, ArrayList<String> list) {
        if (word.length() == 0) {
            list.add(s.trim());
            return;
        }
        String[] headAndTail = word.split(" ", 2);
        String tail = headAndTail.length >= 2 ? headAndTail[1] : "";
        doGetSubSequences(tail, s, list);                          // exclude the head value
        doGetSubSequences(tail, s + " " + headAndTail[0], list);   // include the head value
    }

    /**
     * Recursively grows a CART tree over {@code DataSet} and prints one rule per leaf.
     *
     * @param DataSet   space-separated records whose last field is the target category
     * @param path      rule prefix accumulated from ancestor splits
     * @param alpha     current depth
     * @param alpha_max maximum depth; a node at or beyond this depth becomes a leaf
     */
    public void Cart_tree(List<String> DataSet, String path, int alpha, int alpha_max) {
        // Stop condition 1: depth limit reached or too few records left to split.
        if (alpha >= alpha_max || DataSet.size() <= 2) {
            write_result(DataSet, path);
            return;
        }
        int count_var = DataSet.get(0).split(" ").length - 1;
        String max_split_L = "";
        float max_Gini = -1;
        int max_index = -1;
        for (int i = 0; i < count_var; i++) {
            // Gini_select is exponential in the column's distinct values; call it once.
            List<String> best = Gini_select(DataSet, i);
            float gain = Float.parseFloat(best.get(1));
            if (gain > max_Gini) {
                max_Gini = gain;
                max_split_L = best.get(0);
                max_index = i;
            }
        }
        // Stop condition 2: no split improves purity enough.
        if (max_Gini <= MIN_GAIN) {
            write_result(DataSet, path);
            return;
        }
        List<String> DataSet_L = new ArrayList<String>();
        List<String> DataSet_R = new ArrayList<String>();
        DataSet_split(DataSet, max_index, max_split_L, DataSet_L, DataSet_R);
        String max_split_R = Compute_split_R(DataSet, max_index, max_split_L);
        String attribute = this.Var.split(" ")[max_index];
        Cart_tree(DataSet_L, path + "|" + attribute + ":" + max_split_L, alpha + 1, alpha_max);
        Cart_tree(DataSet_R, path + "|" + attribute + ":" + max_split_R, alpha + 1, alpha_max);
    }

    /**
     * Prints a leaf rule: its path, record count, majority category and majority share.
     *
     * @param DataSet records that reach this leaf
     * @param path    rule path leading to this leaf
     */
    private void write_result(List<String> DataSet, String path) {
        Map<String, Integer> map = new HashMap<String, Integer>();
        for (String row : DataSet) {
            String[] fields = row.trim().split(" ");
            String category = fields[fields.length - 1];
            Integer seen = map.get(category);
            map.put(category, seen == null ? 1 : seen + 1);
        }
        int sum_count = 0;
        int max_count = 0;
        String max_Category = "";
        for (Map.Entry<String, Integer> entry : map.entrySet()) {
            // '>=' keeps the last-iterated category on ties, as before.
            if (entry.getValue() >= max_count) {
                max_count = entry.getValue();
                max_Category = entry.getKey();
            }
            sum_count += entry.getValue();
        }
        int count = DataSet.size();
        // Avoid 0/0 = NaN when the leaf is empty (e.g. an input file with no data rows).
        float accuracy_rate = sum_count == 0 ? 0f : ((float) max_count) / sum_count;
        String[] varNames = this.Var.split(" ");
        System.out.println("Rule:" + path + ".   Count:" + count + ".   "
                + varNames[varNames.length - 1] + ":" + max_Category
                + ".   Accuracy_rate:" + accuracy_rate);
    }

    /**
     * Returns the values of column {@code index} that are not in {@code split_L}, i.e. the other
     * half of the binary split.
     *
     * @param DataSet space-separated records
     * @param index   attribute column index
     * @param split_L space-separated values on the left side of the split
     * @return space-separated values on the right side of the split
     */
    private String Compute_split_R(List<String> DataSet, int index, String split_L) {
        Set<String> set = new HashSet<String>();
        for (String row : DataSet) {
            set.add(row.split(" ")[index]);
        }
        for (String value : split_L.trim().split(" ")) {
            set.remove(value);
        }
        StringBuilder split_R = new StringBuilder();
        for (String value : set) {
            split_R.append(' ').append(value);
        }
        return split_R.toString().trim();
    }

    /**
     * Partitions {@code DataSet} by column {@code max_index}. Records whose value is in
     * {@code max_split_L} go to {@code DataSet_L}; all others go to {@code DataSet_R}.
     */
    private void DataSet_split(List<String> DataSet, int max_index,
            String max_split_L, List<String> DataSet_L, List<String> DataSet_R) {
        Set<String> leftValues = new HashSet<String>();
        for (String value : max_split_L.trim().split(" ")) {
            leftValues.add(value);
        }
        for (String row : DataSet) {
            if (leftValues.contains(row.split(" ")[max_index])) {
                DataSet_L.add(row);
            } else {
                DataSet_R.add(row);
            }
        }
    }

    /**
     * Reads a whitespace-separated data file and prints the CART rules (maximum depth 2).
     *
     * <p>The first non-blank line is the header (attribute names, target last). Blank lines are
     * skipped, and runs of whitespace are collapsed to single spaces.
     *
     * @param args optional {@code args[0]}: input file path (defaults to the original path)
     * @throws IOException if the file cannot be read
     */
    public static void main(String[] args) throws IOException {
        String inputPath = args.length > 0 ? args[0] : "F:/数据挖掘--算法实现/cart算法/input.txt";
        String Var = null;
        List<String> DataSet = new ArrayList<String>();
        try (BufferedReader br = new BufferedReader(new FileReader(inputPath))) {
            String line;
            while ((line = br.readLine()) != null) {
                // Normalize whitespace so every field is separated by exactly one space.
                line = line.trim().replaceAll("\\s+", " ");
                if (line.isEmpty()) {
                    continue;
                }
                if (Var == null) {
                    Var = line;
                } else {
                    DataSet.add(line);
                }
            }
        }
        if (Var == null) {
            return; // empty file: no header, nothing to mine
        }
        Cart a = new Cart();
        a.Var = Var;
        a.Cart_tree(DataSet, "", 0, 2);
    }
}


输入:

age income student credit_rating buys_computer
youth high no fair no
youth high no excellent no
middle_aged high no fair yes
senior medium no fair yes
senior low yes fair yes
senior low yes excellent no
middle_aged low yes excellent yes
youth medium no fair no
youth low yes fair yes
senior medium yes fair yes
youth medium yes excellent yes
middle_aged medium no excellent yes
middle_aged high yes fair yes
senior medium no excellent no

数据格式说明:第一行为变量名,其中最后一列 buys_computer 是目标变量,其余各列为属性;之后每行是一条用户数据,各字段之间以空格分隔。


输出结果:

Rule:|age:middle_aged.   Count:4.   buys_computer:yes.   Accuracy_rate:1.0
Rule:|age:senior youth|student:yes.   Count:5.   buys_computer:yes.   Accuracy_rate:0.8
Rule:|age:senior youth|student:no.   Count:5.   buys_computer:no.   Accuracy_rate:0.8

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值