import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
/**
 * Minimal CART-style decision tree builder for categorical data.
 *
 * <p>Every record is one line of space-separated tokens: the leading tokens are
 * attribute values and the last token is the class label. The tree is grown by
 * repeatedly choosing the (column, value-subset) binary split with the largest
 * Gini impurity reduction, and each leaf is printed to standard output as a rule.
 */
public class Cart {
    /** Space-separated column names (the header line); the last name is the target variable. */
    String Var = "";

    /**
     * Computes the Gini impurity reduction obtained by a binary split.
     *
     * @param Target rows of the form {@code "<attribute value> <class label>"}
     * @param Split  space-separated attribute values that form one side of the split;
     *               all other values form the other side
     * @return impurity of {@code Target} minus the size-weighted impurity of the two
     *         partitions; {@code 0} when {@code Target} is empty (no data, nothing to gain)
     */
    public float Gini_compute(List<String> Target, String Split) {
        if (Target.isEmpty()) {
            // Without this guard the weights below are 0/0 and the result is NaN.
            return 0f;
        }
        Set<String> splitValues = new HashSet<String>();
        for (String value : Split.split(" ")) {
            splitValues.add(value);
        }
        List<String> inside = new ArrayList<String>();
        List<String> outside = new ArrayList<String>();
        for (String row : Target) {
            if (splitValues.contains(row.split(" ")[0])) {
                inside.add(row);
            } else {
                outside.add(row);
            }
        }
        int total = inside.size() + outside.size();
        float weighted = Gini_index(inside) * ((float) inside.size()) / total;
        weighted += Gini_index(outside) * ((float) outside.size()) / total;
        return Gini_index(Target) - weighted;
    }

    /**
     * Computes the Gini impurity {@code 1 - sum(p_k^2)} of the class labels in a row set.
     *
     * @param Target rows whose second space-separated token is the class label
     * @return the impurity; an empty list yields {@code 1} because the sum is empty
     */
    public float Gini_index(List<String> Target) {
        // HashMap is kept deliberately: its iteration order matches the HashSet the
        // summation used to run over, so float rounding (and thus tie-breaking between
        // equally good splits) is unchanged.
        Map<String, Integer> counts = new HashMap<String, Integer>();
        for (String row : Target) {
            String label = row.split(" ")[1];
            Integer seen = counts.get(label);
            counts.put(label, seen == null ? 1 : seen.intValue() + 1);
        }
        int total = Target.size();
        float sumOfSquares = 0;
        for (Integer count : counts.values()) {
            float share = ((float) count.intValue()) / total;
            sumOfSquares = sumOfSquares + share * share;
        }
        return 1 - sumOfSquares;
    }

    /**
     * Finds the best binary split of one column by exhaustive search.
     *
     * <p>Every subset of the column's distinct values (including the empty and the full
     * subset) is evaluated, so the cost grows as 2^k in the number of distinct values.
     *
     * @param DataSet full records (attribute tokens followed by the class label)
     * @param i       zero-based index of the column to evaluate
     * @return a two-element list: the winning subset as space-separated values, and its
     *         impurity reduction formatted with {@link String#valueOf(float)}
     */
    public List<String> Gini_select(List<String> DataSet, int i) {
        List<String> pairs = new ArrayList<String>();
        Set<String> distinctValues = new HashSet<String>();
        for (String row : DataSet) {
            String[] fields = row.split(" ");
            pairs.add(fields[i] + " " + fields[fields.length - 1]);
            distinctValues.add(fields[i]);
        }
        StringBuilder joined = new StringBuilder();
        for (String value : distinctValues) {
            joined.append(" ").append(value);
        }
        List<String> candidates = new ArrayList<String>();
        doGetSubSequences(joined.toString().trim(), "", candidates);

        String bestSplit = candidates.get(0);
        float bestGain = Gini_compute(pairs, bestSplit);
        for (int j = 1; j < candidates.size(); j++) {
            String candidate = candidates.get(j);
            // Evaluate each candidate once; strict '>' keeps the earliest of tied subsets.
            float gain = Gini_compute(pairs, candidate);
            if (gain > bestGain) {
                bestGain = gain;
                bestSplit = candidate;
            }
        }
        List<String> result = new ArrayList<String>();
        result.add(bestSplit);
        result.add(String.valueOf(bestGain));
        return result;
    }

    /**
     * Appends every subset of the space-separated tokens in {@code word} to {@code list}.
     *
     * @param word remaining tokens still to be decided, space-separated
     * @param s    tokens chosen so far (may carry a leading space; trimmed on output)
     * @param list receives one space-separated string per subset
     */
    private static void doGetSubSequences(String word, String s, List<String> list) {
        if (word.length() == 0) {
            list.add(s.trim());
            return;
        }
        String[] headAndTail = word.split(" ", 2);
        String tail = headAndTail.length >= 2 ? headAndTail[1] : "";
        doGetSubSequences(tail, s, list);                        // subsets without the head
        doGetSubSequences(tail, s + " " + headAndTail[0], list); // subsets with the head
    }

    /**
     * Recursively grows the tree and prints one rule per leaf.
     *
     * @param DataSet   records reaching this node
     * @param path      textual description of the splits taken so far
     * @param alpha     current depth
     * @param alpha_max depth at which growth stops
     */
    public void Cart_tree(List<String> DataSet, String path, int alpha, int alpha_max) {
        // Stop condition 1: depth limit reached or too few records to split.
        if (alpha == alpha_max || DataSet.size() <= 2) {
            write_result(DataSet, path);
            return;
        }
        int featureCount = DataSet.get(0).split(" ").length - 1;
        String bestSplitL = "";
        float bestGain = -1;
        int bestIndex = -1;
        for (int i = 0; i < featureCount; i++) {
            // Gini_select is an exponential search; run it once per column.
            List<String> selection = Gini_select(DataSet, i);
            float gain = Float.parseFloat(selection.get(1));
            if (gain > bestGain) {
                bestGain = gain;
                bestSplitL = selection.get(0);
                bestIndex = i;
            }
        }
        // Stop condition 2: no split improves impurity noticeably.
        if (bestGain <= 0.01) {
            write_result(DataSet, path);
            return;
        }
        List<String> left = new ArrayList<String>();
        List<String> right = new ArrayList<String>();
        DataSet_split(DataSet, bestIndex, bestSplitL, left, right);
        String bestSplitR = Compute_split_R(DataSet, bestIndex, bestSplitL);
        String column = columnName(bestIndex);
        Cart_tree(left, path + "|" + column + ":" + bestSplitL, alpha + 1, alpha_max);
        Cart_tree(right, path + "|" + column + ":" + bestSplitR, alpha + 1, alpha_max);
    }

    /**
     * Returns the header name of a column, or {@code "column<index>"} when {@link #Var}
     * does not contain a name for it (for example when no header was supplied).
     */
    private String columnName(int index) {
        String[] names = this.Var.split(" ");
        return index < names.length ? names[index] : "column" + index;
    }

    /**
     * Prints the rule for a leaf: path, record count, majority class and its share.
     *
     * @param DataSet records that ended up in this leaf
     * @param path    textual description of the splits leading to this leaf
     */
    private void write_result(List<String> DataSet, String path) {
        Map<String, Integer> counts = new HashMap<String, Integer>();
        for (String row : DataSet) {
            String[] fields = row.trim().split(" ");
            String label = fields[fields.length - 1];
            Integer seen = counts.get(label);
            counts.put(label, seen == null ? 1 : seen.intValue() + 1);
        }
        int total = 0;
        int bestCount = 0;
        String bestLabel = "";
        for (Map.Entry<String, Integer> entry : counts.entrySet()) {
            int count = entry.getValue().intValue();
            // '>=' means the last label in map iteration order wins a tie.
            if (count >= bestCount) {
                bestCount = count;
                bestLabel = entry.getKey();
            }
            total = total + count;
        }
        float accuracy = ((float) bestCount) / total;
        String[] names = this.Var.split(" ");
        System.out.println("Rule:" + path + ". Count:" + DataSet.size() + ". "
                + names[names.length - 1] + ":" + bestLabel + ". Accuracy_rate:" + accuracy);
    }

    /**
     * Returns the values of a column that are NOT in the left-hand split.
     *
     * @param DataSet records at the node being split
     * @param index   zero-based column index
     * @param split_L space-separated values of the left-hand side
     * @return the remaining distinct values of the column, space-separated
     */
    private String Compute_split_R(List<String> DataSet, int index, String split_L) {
        Set<String> remaining = new HashSet<String>();
        for (String row : DataSet) {
            remaining.add(row.split(" ")[index]);
        }
        for (String value : split_L.trim().split(" ")) {
            remaining.remove(value);
        }
        StringBuilder joined = new StringBuilder();
        for (String value : remaining) {
            joined.append(" ").append(value);
        }
        return joined.toString().trim();
    }

    /**
     * Partitions records by whether their value in a column belongs to the left split.
     *
     * @param DataSet     records to partition
     * @param max_index   zero-based column index
     * @param max_split_L space-separated values of the left-hand side
     * @param DataSet_L   output list receiving records whose value is in the split
     * @param DataSet_R   output list receiving all other records
     */
    private void DataSet_split(List<String> DataSet, int max_index,
            String max_split_L, List<String> DataSet_L, List<String> DataSet_R) {
        // Tokenise the split once instead of once per record per token.
        Set<String> leftValues = new HashSet<String>();
        for (String value : max_split_L.trim().split(" ")) {
            leftValues.add(value);
        }
        for (String row : DataSet) {
            if (leftValues.contains(row.split(" ")[max_index])) {
                DataSet_L.add(row);
            } else {
                DataSet_R.add(row);
            }
        }
    }

    /**
     * Reads a data file (first non-blank line is the header) and prints a depth-2 tree.
     *
     * @param args optional: {@code args[0]} is the input file path; when absent the
     *             original hard-coded location is used
     * @throws IOException if the file cannot be opened or read
     */
    public static void main(String[] args) throws IOException {
        String inputPath = args.length > 0 ? args[0] : "F:/数据挖掘--算法实现/cart算法/input.txt";
        List<String> DataSet = new ArrayList<String>();
        String header = "";
        BufferedReader br = new BufferedReader(new FileReader(inputPath));
        try {
            boolean headerRead = false;
            String line;
            while ((line = br.readLine()) != null) {
                if (line.trim().length() == 0) {
                    // Blank lines (e.g. a trailing newline) carry no record and would
                    // otherwise cause an index error when the row is tokenised.
                    continue;
                }
                if (!headerRead) {
                    headerRead = true;
                    header = line;
                    continue;
                }
                DataSet.add(line);
            }
        } finally {
            br.close(); // always release the file handle, even if reading fails
        }
        Cart a = new Cart();
        a.Var = header;
        a.Cart_tree(DataSet, "", 0, 2);
    }
}
输入:
age income student credit_rating buys_computer
youth high no fair no
youth high no excellent no
middle_aged high no fair yes
senior medium no fair yes
senior low yes fair yes
senior low yes excellent no
middle_aged low yes excellent yes
youth medium no fair no
youth low yes fair yes
senior medium yes fair yes
youth medium yes excellent yes
middle_aged medium no excellent yes
middle_aged high yes fair yes
senior medium no excellent no
数据格式说明:第一行为变量名(表头),最后一列(此处为buys_computer)是目标变量;其余各行为用户数据,每行的各个字段以单个空格分隔。
输出结果:
Rule:|age:middle_aged. Count:4. buys_computer:yes. Accuracy_rate:1.0
Rule:|age:senior youth|student:yes. Count:5. buys_computer:yes. Accuracy_rate:0.8
Rule:|age:senior youth|student:no. Count:5. buys_computer:no. Accuracy_rate:0.8