weka之ID3

@Override
    public void buildClassifier(Instances data) throws Exception 
    {
        //检验算法能否直接处理数据
        getCapabilities().testWithFail(data);

        //删除带有缺失class标记的数据
        data=new Instances(data);
        data.deleteWithMissingClass();

        makeTree(data);
    }
private void makeTree(Instances data) throws Exception 
    {
        //如果没有是数据集到达这个节点,返回叶子节点
        if(data.numInstances()==0)
        {
            //当前分裂属性为空
            m_Attribute=null;
            //当前类别属性为空
            m_ClassAttribute=null;
            //分类比例
            m_Distribution=new double[data.numClasses()];

            return;
        }
        //保存每个属性的信息增益
        double infoGains[]=new double[data.numAttributes()];
        Enumeration attEnum=data.enumerateAttributes();
        while (attEnum.hasMoreElements()) 
        {
            Attribute att = (Attribute) attEnum.nextElement();
            infoGains[att.index()]=computeInfoGain(data, att);
        }
        //选取InfoGain最大的属性作为分裂属性
        m_Attribute=data.attribute(Utils.maxIndex(infoGains));

        System.err.println("我要打印InfoGain了");
        for(int i=0;i<infoGains.length;i++)
        {
            System.err.print(infoGains[i]+" ");
        }
        System.err.println();


        //当前分裂属性信息增益为0,说明是叶子节点
        if(infoGains[m_Attribute.index()]==0)
        {
            m_Attribute=null;
            m_Distribution=new double[data.numClasses()];

            //下面是投票表决
            Enumeration instEnum=data.enumerateInstances();
            while (instEnum.hasMoreElements()) 
            {
                Instance instance = (Instance) instEnum.nextElement();
                //类别取值下表对应的++
                m_Distribution[(int)instance.classValue()]++;
            }
            //归一化成0-1之间
            Utils.normalize(m_Distribution);
            //类别取值
            m_ClassValue=Utils.maxIndex(m_Distribution);
            m_ClassAttribute=data.classAttribute();
        }
        else//对数据进行划分
        {
            Instances[] splitData=splitData(data, m_Attribute);
            m_Successors=new Id3[m_Attribute.numValues()];
            for(int i=0;i<m_Attribute.numValues();i++)
            {
                m_Successors[i]=new Id3();
                m_Successors[i].makeTree(splitData[i]);
            }
        }
    }

    /*
     * 计算信息增益
     */
    private double computeInfoGain(Instances data, Attribute att) throws Exception 
    {
        //先计算整体的信息熵
        double infoGain=computeEntropy(data);
        //打印熵,用于调试
        System.err.println("我要打印熵1了");
        System.err.println(infoGain);
        //根据属性划分数据集
        Instances[] splitData=splitData(data, att);
        System.err.println("下面打印出划分好的数据集");
        for (int i = 0; i < splitData.length; i++) 
        {
            System.out.println(splitData[i].numInstances());
        }
        for (int i = 0; i < splitData.length; i++) 
        {
            if(splitData[i].numInstances()>0)
            {
                double temp1=((double)splitData[i].numInstances()/
                        (double)data.numInstances());
                double tempEntropy=computeEntropy(splitData[i]);
                double temp2=temp1*tempEntropy;
                infoGain-=temp1*temp2;
                System.err.println(infoGain);
            }
        }
        return infoGain;
    }

    //计算熵的函数
    private double computeEntropy(Instances data) 
    {
        double[] classCounts=new double[data.numClasses()];
        Enumeration instEnum=data.enumerateInstances();
        while (instEnum.hasMoreElements()) 
        {
            Instance inst = (Instance) instEnum.nextElement();
            classCounts[(int)inst.classValue()]++;
        }
        double entropy=0;
        double numInstances=data.numInstances();
        for (int i = 0; i < data.numClasses(); i++) 
        {
            if(classCounts[i]>0)
            {
                entropy-=((double)classCounts[i]/(double)numInstances)*Utils.log2((double)classCounts[i]/(double)numInstances);
            }
        }
        return entropy;
    }

    private Instances[] splitData(Instances data, Attribute att) 
    {
        //创建子数据集数组
        Instances[] splitData=new Instances[att.numValues()];
        //为数组分配空间
        for (int i = 0; i < splitData.length; i++) 
        {
            splitData[i]=new Instances(data,data.numInstances());
        }
        Enumeration instEnum=data.enumerateInstances();
        while (instEnum.hasMoreElements()) 
        {
            Instance inst = (Instance) instEnum.nextElement();
            splitData[(int)inst.value(att)].add(inst);
        }
        //处理空间,删除多余的没有用到的空间
        for (int i = 0; i < splitData.length; i++) 
        {
            splitData[i].compactify();
        }
        //返回划分好的数据集
        return splitData;
    }

    public double[] distrbutionForInstance(Instance instance) throws Exception
    {
        if(instance.hasMissingValue())
        {
            throw new Exception("Id3"+ "算法不能处理缺失值");
        }
        if(m_Attribute==null)
        {
            return m_Distribution;
        }
        else
        {
            return m_Successors[(int)instance.value(m_Attribute)].distrbutionForInstance(instance);
        }
    }

     public double classifyInstance(Instance instance) throws Exception 
     {
        if(instance.hasMissingValue())
        {
            throw new Exception("Id3"+ "算法不能处理缺失值");
        }
        if(m_Attribute==null)
        {
            return Utils.maxIndex(m_Distribution);
        }
        else
        {
            return m_Successors[(int)instance.value(m_Attribute)].classifyInstance(instance);
        }
     }
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值