ID3算法Java实现

根据《统计学习方法》一书中介绍的步骤实现。

    import java.util.List;
    // Computes the empirical entropy H(D) of a data set (the first quantity
    // needed by ID3). The class label of each row is read from its last column.
    public class Entroy {
        /**
         * Returns the empirical entropy H(D) = -sum_k p_k * log2(p_k) of the class
         * labels in {@code list}, where p_k is the fraction of rows labelled category[k].
         *
         * @param list     the samples; the label is taken from the last column
         *                 (column count is read from the first row)
         * @param category the possible class labels
         * @return the entropy in bits; 0 for an empty list. Labels that never occur
         *         contribute 0 (convention 0 * log 0 = 0) instead of producing NaN.
         */
        public double HD(List<Object[]> list, Object[] category){
            if (list.isEmpty()) {
                return 0;
            }
            int labelColumn = list.get(0).length - 1;
            double HD = 0;
            for (Object label : category) {
                int num = 0;
                for (Object[] row : list) {
                    // equals(), not ==: equal labels held in distinct objects must still match
                    if (java.util.Objects.equals(row[labelColumn], label)) {
                        num++;
                    }
                }
                if (num == 0) {
                    // 0 * log2(0) is taken as 0; computing it would give 0 * -Infinity = NaN
                    continue;
                }
                double p = (double) num / list.size();
                HD -= p * (Math.log(p) / Math.log(2.0));
            }
            return HD;
        }
    }
    
    import java.util.List;
    // Computes the empirical conditional entropy H(D|A) of a data set with
    // respect to a single feature A.
    public class Conditition {
        /**
         * Returns H(D|A) = sum_k |D_k|/|D| * H(D_k) for the feature in column n,
         * where D_k holds the rows whose column n equals the k-th possible value.
         *
         * @param list  the samples; the class label is in each row's last column
         * @param array value domains: array.get(i) lists the possible values of
         *              column i, and the last entry lists the class labels
         * @param n     index of the feature column A
         * @return the conditional entropy in bits (0 for an empty list)
         */
        public double GD(List<Object[]> list,List<Object[]> array,  int n){
            Object[] values = array.get(n);
            Object[] category = array.get(array.size() - 1);
            double gd = 0;
            for (Object value : values) {
                int num = 0;
                for (Object[] row : list) {
                    if (java.util.Objects.equals(row[n], value)) {
                        num++;
                    }
                }
                if (num == 0) {
                    // An empty subset has weight 0; evaluating it would only produce 0/0 = NaN
                    continue;
                }
                gd += ((double) num / list.size()) * HH(list, value, n, category);
            }
            return gd;
        }

        /**
         * Returns the entropy of the class labels among the rows whose column n
         * equals {@code object}.
         *
         * @param list     the samples; the class label is in each row's last column
         * @param object   the feature value selecting the subset
         * @param n        index of the feature column
         * @param category the possible class labels
         * @return the entropy in bits; 0 when no row has that feature value
         */
        public double HH(List<Object[]> list, Object object, int n, Object[] category){
            if (list.isEmpty()) {
                return 0;
            }
            int labelColumn = list.get(0).length - 1;
            int num = 0;
            for (Object[] row : list) {
                if (java.util.Objects.equals(row[n], object)) {
                    num++;
                }
            }
            if (num == 0) {
                return 0;
            }
            double HD = 0;
            for (Object label : category) {
                int nums = 0;
                for (Object[] row : list) {
                    if (java.util.Objects.equals(row[n], object)
                            && java.util.Objects.equals(row[labelColumn], label)) {
                        nums++;
                    }
                }
                if (nums == 0) {
                    // Skip the 0 * log 0 term; it must not discard the terms already summed
                    continue;
                }
                double p = (double) nums / num;
                HD -= p * (Math.log(p) / Math.log(2.0));
            }
            return HD;
        }
    }
    // Prints the selected feature (largest information gain) and returns its index.
    public class OutPut {
        /**
         * Picks the index with the largest gain; on ties the earliest index wins.
         * The name of the chosen feature is printed to standard output.
         *
         * @param GD      information gain of each feature
         * @param feature feature names, indexed like {@code GD}
         * @return the index of the selected feature
         */
        public int output(double[] GD, String[] feature){
            int best = 0;
            int idx = 1;
            while (idx < GD.length) {
                best = GD[idx] > GD[best] ? idx : best;
                idx++;
            }
            String chosen = feature[best];
            System.out.println(chosen);
            return best;
        }
    }
    
    import java.util.ArrayList;
    import java.util.List;
    
    import static java.lang.Float.NaN;
    //ID3算法
   import java.util.ArrayList;
import java.util.List;

import static java.lang.Float.NaN;

/**
 * ID3 decision-tree construction. At each node the feature with the largest
 * information gain is chosen, and its name is printed by {@code OutPut.output}.
 */
public class ID3 {
    /**
     * Chooses the feature with the largest information gain g(D,A) = H(D) - H(D|A),
     * prints its name, splits the samples on it and recurses into every impure subset.
     *
     * @param list    the samples; each row holds feature values followed by the class label
     * @param feature feature names, indexed like the columns of a row
     * @param array   value domains: array.get(i) lists the possible values of column i,
     *                and the last entry lists the class labels
     */
    public void id3(List<Object[]> list,  String[] feature , List<Object[]> array){
        if (list.isEmpty()) {
            return; // nothing to split
        }
        int len = list.get(0).length;
        Object[] category = array.get(len - 1);
        double HD = new Entroy().HD(list, category);
        Conditition c = new Conditition();
        double[] GD = new double[array.size() - 1];
        for (int i = 0; i < GD.length; i++) {
            GD[i] = HD - c.GD(list, array, i);
        }
        int max = new OutPut().output(GD, feature);
        // Partition the data by the chosen feature's values, then recurse where needed.
        for (List<Object[]> subset : partition(list, max, array.get(max))) {
            // Decide purity per subset (the original flag leaked across subsets).
            // Empty and pure subsets are leaves. A subset as large as the current set
            // means the split made no progress, and recursing would loop forever.
            if (subset.size() < list.size() && !isPure(subset, len - 1)) {
                id3(subset, feature, array);
            }
        }
    }

    // Groups rows by their value in the given column: one subset per entry of values, in that order.
    private static List<List<Object[]>> partition(List<Object[]> list, int column, Object[] values) {
        List<List<Object[]>> subsets = new ArrayList<>(values.length);
        for (Object value : values) {
            List<Object[]> subset = new ArrayList<>();
            for (Object[] row : list) {
                if (java.util.Objects.equals(row[column], value)) {
                    subset.add(row);
                }
            }
            subsets.add(subset);
        }
        return subsets;
    }

    // Returns true when every row carries the same label as the first row (true for no rows).
    private static boolean isPure(List<Object[]> rows, int labelColumn) {
        for (Object[] row : rows) {
            if (!java.util.Objects.equals(row[labelColumn], rows.get(0)[labelColumn])) {
                return false;
            }
        }
        return true;
    }
}
//测试
import java.util.ArrayList;
import java.util.List;

import static java.lang.Float.NaN;

public class test {
    public static void main(String[] args) {
        String[] feature = {"年龄","有工作","有自己的房子","信贷情况","类别"};
        Object[] age ={"青年","中年","老年"};
        Object[] work = {'是','否'};
        Object[] house = {'是','否'};
        Object[] loan = {"一般",'好',"非常好"};
        Object[] category = {'是','否'};
        List<Object[]> array = new ArrayList<Object[]>();
        array.add(age);
        array.add(work);
        array.add(house);
        array.add(loan);
        array.add(category);
        Object[] o1 = {age[0],work[1],house[1],loan[0],category[1]};
        Object[] o2 = {age[0],work[1],house[1],loan[1],category[1]};
        Object[] o3 = {age[0],work[0],house[1],loan[1],category[0]};
        Object[] o4 = {age[0],work[0],house[0],loan[0],category[0]};
        Object[] o5 = {age[0],work[1],house[1],loan[0],category[1]};
        Object[] o6 = {age[1],work[1],house[1],loan[0],category[1]};
        Object[] o7 = {age[1],work[1],house[1],loan[1],category[1]};
        Object[] o8 = {age[1],work[0],house[0],loan[1],category[0]};
        Object[] o9 = {age[1],work[1],house[0],loan[2],category[0]};
        Object[] o10 = {age[1],work[1],house[0],loan[2],category[0]};
        Object[] o11 = {age[2],work[1],house[0],loan[2],category[0]};
        Object[] o12 = {age[2],work[1],house[0],loan[1],category[0]};
        Object[] o13 = {age[2],work[0],house[1],loan[1],category[0]};
        Object[] o14 = {age[2],work[0],house[1],loan[2],category[0]};
        Object[] o15 = {age[2],work[1],house[1],loan[0],category[1]};
        List<Object[]> list = new ArrayList<Object[]>();
        list.add(o1);
        list.add(o2);
        list.add(o3);
        list.add(o4);
        list.add(o5);
        list.add(o6);
        list.add(o7);
        list.add(o8);
        list.add(o9);
        list.add(o10);
        list.add(o11);
        list.add(o12);
        list.add(o13);
        list.add(o14);
        list.add(o15);
        ID3 id3 = new ID3();
        id3.id3(list,feature,array);

    }
}
  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值