@Override
public void buildClassifier(Instances data) throws Exception
{
    // Reject datasets this learner cannot handle (checked against getCapabilities()).
    getCapabilities().testWithFail(data);

    // Work on a copy so that removing instances below does not alter the caller's dataset,
    // then drop every instance whose class label is missing.
    Instances trainingData = new Instances(data);
    trainingData.deleteWithMissingClass();

    // Grow the tree rooted at this node.
    makeTree(trainingData);
}
/**
 * Recursively grows the ID3 subtree rooted at this node from {@code data}.
 *
 * <p>Three outcomes are possible:
 * <ul>
 *   <li>No instances: empty leaf with an all-zero {@code m_Distribution}.</li>
 *   <li>Best information gain is (numerically) zero: leaf whose {@code m_Distribution} is the
 *       normalized class histogram of {@code data}; {@code m_ClassValue} is its arg-max.</li>
 *   <li>Otherwise: inner node splitting on {@code m_Attribute}, with one child per value of
 *       that attribute in {@code m_Successors}.</li>
 * </ul>
 *
 * @param data training instances that reached this node
 * @throws Exception propagated from the information-gain computation
 */
private void makeTree(Instances data) throws Exception
{
    // No instances reached this node: empty leaf.
    if (data.numInstances() == 0)
    {
        m_Attribute = null;
        m_ClassAttribute = null;
        m_Successors = null; // drop any subtree left over from a previous build
        m_Distribution = new double[data.numClasses()];
        return;
    }

    // Information gain of each attribute, indexed by attribute position. Assumes
    // enumerateAttributes() omits the class attribute (as in Weka's Instances API),
    // so the class slot stays 0.
    double[] infoGains = new double[data.numAttributes()];
    Enumeration attEnum = data.enumerateAttributes();
    while (attEnum.hasMoreElements())
    {
        Attribute att = (Attribute) attEnum.nextElement();
        infoGains[att.index()] = computeInfoGain(data, att);
    }
    m_Attribute = data.attribute(Utils.maxIndex(infoGains));

    // Gains are differences of floating-point entropies, so a split that carries no
    // information can come out as a tiny nonzero value; compare with a tolerance.
    if (Utils.eq(infoGains[m_Attribute.index()], 0))
    {
        // Leaf: store the normalized class distribution of the instances here.
        m_Attribute = null;
        m_Successors = null; // drop any subtree left over from a previous build
        m_Distribution = new double[data.numClasses()];
        Enumeration instEnum = data.enumerateInstances();
        while (instEnum.hasMoreElements())
        {
            Instance instance = (Instance) instEnum.nextElement();
            m_Distribution[(int) instance.classValue()]++;
        }
        Utils.normalize(m_Distribution);
        m_ClassValue = Utils.maxIndex(m_Distribution);
        m_ClassAttribute = data.classAttribute();
    }
    else
    {
        // Inner node: one child per value of the chosen attribute.
        Instances[] splitData = splitData(data, m_Attribute);
        m_Successors = new Id3[m_Attribute.numValues()];
        for (int i = 0; i < m_Attribute.numValues(); i++)
        {
            m_Successors[i] = new Id3();
            m_Successors[i].makeTree(splitData[i]);
        }
    }
}
/**
 * Computes the information gain of splitting {@code data} on {@code att}:
 * {@code H(D) - sum over values v of (|D_v| / |D|) * H(D_v)}.
 *
 * <p>Empty partitions contribute nothing and are skipped.
 *
 * @param data instances to split
 * @param att  attribute to split on
 * @return the information gain (in bits, since entropies use log base 2)
 * @throws Exception declared for callers; the body itself throws no checked exception
 */
private double computeInfoGain(Instances data, Attribute att) throws Exception
{
    double infoGain = computeEntropy(data);
    Instances[] splitData = splitData(data, att);
    for (int i = 0; i < splitData.length; i++)
    {
        if (splitData[i].numInstances() > 0)
        {
            // Each partition's entropy is weighted exactly once by its share of the data.
            // (The previous code multiplied by this weight twice, i.e. by weight^2.)
            double weight = (double) splitData[i].numInstances() / (double) data.numInstances();
            infoGain -= weight * computeEntropy(splitData[i]);
        }
    }
    return infoGain;
}
/**
 * Computes the Shannon entropy (base 2) of the class labels in {@code data}.
 *
 * @param data instances whose class labels are counted
 * @return {@code -sum p_c * log2(p_c)} over classes with a nonzero count; 0 for no instances
 */
private double computeEntropy(Instances data)
{
    double[] counts = new double[data.numClasses()];
    for (int k = 0; k < data.numInstances(); k++)
    {
        counts[(int) data.instance(k).classValue()]++;
    }

    double total = data.numInstances();
    double result = 0;
    for (double count : counts)
    {
        // Classes that never occur contribute 0 (and log2(0) would be undefined).
        if (count > 0)
        {
            double p = count / total;
            result -= p * Utils.log2(p);
        }
    }
    return result;
}
/**
 * Partitions {@code data} by the value of {@code att}, one subset per attribute value.
 *
 * <p>Assumes {@code att} is nominal, so {@code value(att)} is a value index in
 * {@code [0, att.numValues())}.
 *
 * @param data instances to partition
 * @param att  attribute whose value selects the subset
 * @return array of length {@code att.numValues()}; subset {@code b} holds the instances with value index {@code b}
 */
private Instances[] splitData(Instances data, Attribute att)
{
    int numBranches = att.numValues();
    Instances[] subsets = new Instances[numBranches];
    for (int b = 0; b < numBranches; b++)
    {
        // Same header as data, empty, with capacity for every instance.
        subsets[b] = new Instances(data, data.numInstances());
    }

    for (int k = 0; k < data.numInstances(); k++)
    {
        Instance current = data.instance(k);
        subsets[(int) current.value(att)].add(current);
    }

    // Release the unused capacity reserved above.
    for (Instances subset : subsets)
    {
        subset.compactify();
    }
    return subsets;
}
/**
 * Returns the class probability distribution for {@code instance}.
 *
 * <p>The instance is routed down the tree by its attribute values. The distribution stored at
 * the reached leaf is returned as a copy; for an empty leaf it is all zeros.
 *
 * <p>The correctly spelled name matches the method Weka's classifier framework calls, so the
 * per-leaf probabilities are actually used.
 *
 * @param instance instance to classify; must not contain missing values
 * @return a fresh array with one probability per class
 * @throws Exception if {@code instance} has a missing value
 */
public double[] distributionForInstance(Instance instance) throws Exception
{
    if (instance.hasMissingValue())
    {
        throw new Exception("Id3"+ "算法不能处理缺失值");
    }
    if (m_Attribute == null)
    {
        // Copy so callers cannot mutate the tree's stored distribution.
        return m_Distribution.clone();
    }
    return m_Successors[(int) instance.value(m_Attribute)].distributionForInstance(instance);
}

/**
 * Legacy, misspelled alias kept so existing callers keep compiling.
 *
 * @param instance instance to classify; must not contain missing values
 * @return same as {@link #distributionForInstance(Instance)}
 * @throws Exception if {@code instance} has a missing value
 * @deprecated use {@link #distributionForInstance(Instance)}
 */
@Deprecated
public double[] distrbutionForInstance(Instance instance) throws Exception
{
    return distributionForInstance(instance);
}
/**
 * Predicts the class index for {@code instance}.
 *
 * <p>The instance is routed down the tree by its attribute values. The result is the arg-max
 * of the reached leaf's class distribution.
 *
 * @param instance instance to classify; must not contain missing values
 * @return index of the predicted class, as a double
 * @throws Exception if {@code instance} has a missing value
 */
public double classifyInstance(Instance instance) throws Exception
{
    if (!instance.hasMissingValue())
    {
        // Inner node: delegate to the child for this instance's value; leaf: pick the most probable class.
        return (m_Attribute != null)
            ? m_Successors[(int) instance.value(m_Attribute)].classifyInstance(instance)
            : Utils.maxIndex(m_Distribution);
    }
    throw new Exception("Id3"+ "算法不能处理缺失值");
}