1.伪代码
- 将所有属性放入可选属性集A中;
- 若训练集D的所有样本属于同一类别k,则构建叶节点,节点的类标签分布为[第k个为1.0,其他为0.0],返回该叶节点;
- 若A为空,使用ClassDistribution获得类标签分布,以此构建叶节点,并返回叶节点;
- 否则,计算A中所有属性的信息增益,找出信息增益最大的属性Ag;
- 若Ag的信息增益小于阈值,则使用ClassDistribution获得类标签分布,以此构建叶节点,并返回叶节点;
- 否则,基于当前训练集构建节点,然后基于Ag的取值个数N将当前训练集进行划分成N个子集,对第i个非空子集,以A-Ag为可选属性集,递归1-5得到所有子树,返回该节点及其子树。
2.代码
package weka.classifiers.xwq;
import java.util.ArrayList;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.core.Attribute;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Utils;
public class ID3 extends Classifier
{
/**
* 训练数据集
*/
private Instances m_Train = null;
/**
* 基于训练集构建的决策树
*/
private TreeNodes m_Tree = null;
/**
* 属性值个数
*/
private int m_NumAttributes = -1;
/**
* 可选属性的集合
*/
private ArrayList<Attribute> m_AttributesOptions = new ArrayList<>();
/**
* 类标签的个数
*/
private int m_NumClassValues = -1;
/**
* 阈值
*/
private double m_Threshold = 0.1;
@Override
public void buildClassifier(Instances data) throws Exception
{
// TODO Auto-generated method stub
m_Train = new Instances(data);
m_NumAttributes = m_Train.numAttributes();
for (int i = 0; i < m_NumAttributes; i++)
if (i!=m_Train.classIndex())
m_AttributesOptions.add(m_Train.att