Machine Learning: Decision Trees

Learning source: 日撸 Java 三百行 (Days 61-70, decision trees and ensemble learning)

  1. Decision making is something people do every day, such as deciding what to eat for lunch or which stock to buy.
  2. A decision tree is a tree built for making decisions. Its core question is which attribute to use to split the current data; different algorithms may select different attributes.
  3. Conditional entropy. Let D be the current data block, A a candidate attribute with values a_1, ..., a_n, D_i the subset of D taking value a_i, and D_{ik} the instances of D_i belonging to class k. The conditional entropy of the class given A is
    \[
    H(D \mid A) = \sum_{i=1}^{n} \frac{|D_i|}{|D|} H(D_i)
                = -\sum_{i=1}^{n} \frac{|D_i|}{|D|} \sum_{k=1}^{K} \frac{|D_{ik}|}{|D_i|} \log \frac{|D_{ik}|}{|D_i|}.
    \]
    ID3 splits on the attribute that minimizes H(D | A), which is equivalent to maximizing the information gain. A worked example follows the dataset below.
    Dataset:
@relation weather
@attribute Outlook {Sunny, Overcast, Rain}
@attribute Temperature {Hot, Mild, Cool}
@attribute Humidity {High, Normal, Low}
@attribute Windy {FALSE, TRUE}
@attribute Play {N, P}
@data
Sunny,Hot,High,FALSE,N
Sunny,Hot,High,TRUE,N
Overcast,Hot,High,FALSE,P
Rain,Mild,High,FALSE,P
Rain,Cool,Normal,FALSE,P
Rain,Cool,Normal,TRUE,N
Overcast,Cool,Normal,TRUE,P
Sunny,Mild,High,FALSE,N
Sunny,Cool,Normal,FALSE,P
Rain,Mild,Normal,FALSE,P
Sunny,Mild,Normal,TRUE,P
Overcast,Mild,High,TRUE,P
Overcast,Hot,Normal,FALSE,P
Rain,Mild,High,TRUE,N
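
To make the formula concrete, here is the computation for the Outlook attribute on this dataset, using the natural logarithm as the code below does. Sunny covers 5 instances (2 P, 3 N), Overcast covers 4 (all P), and Rain covers 5 (3 P, 2 N):

\[
H(\text{Play} \mid \text{Outlook}) = \frac{5}{14} H\!\left(\frac{2}{5}, \frac{3}{5}\right) + \frac{4}{14} H(1, 0) + \frac{5}{14} H\!\left(\frac{3}{5}, \frac{2}{5}\right) \approx \frac{10}{14} \times 0.673 \approx 0.481 \text{ nats}.
\]

This is the smallest conditional entropy among the four attributes, so ID3 selects Outlook at the root, matching the run result below.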

Code:

package machinelearning.decisiontree;

import weka.core.Instance;
import weka.core.Instances;

import java.io.FileReader;
import java.util.Arrays;

public class ID3 {

    /**
     * The data.
     */
    Instances dataset;

    /**
     * Is this dataset pure (only one label)?
     */
    boolean pure;

    /**
     * The number of classes. For binary classification it is 2.
     */
    int numClasses;

    /**
     * Available instances. Other instances do not belong to this branch.
     */
    int[] availableInstances;

    /**
     * Available attributes. Other attributes have been selected in the path
     * from the root.
     */
    int[] availableAttributes;

    /**
     * The selected attribute.
     */
    int splitAttribute;

    /**
     * The children nodes.
     */
    ID3[] children;

    /**
     * My label. Inner nodes also have a label. For example, a combination such
     * as <outlook = sunny, humidity = high> may never appear in the training
     * data; the inner node's label then serves as the fallback prediction.
     */
    int label;

    /**
     * The prediction, including queried and predicted labels.
     */
    int[] predicts;

    /**
     * A block with no more instances than this threshold cannot be split further.
     */
    static int smallBlockThreshold = 3;


    /**
     * The first constructor. Build the root node from an ARFF file.
     * @param paraFilename The given file name.
     */
    public ID3(String paraFilename) {
        dataset = null;
        try {
            FileReader fileReader = new FileReader(paraFilename);
            dataset = new Instances(fileReader);
            fileReader.close();
        } catch (Exception ee) {
            System.out.println("Cannot read the file: " + paraFilename + "\r\n" + ee);
            System.exit(0);
        }// of try
        
        dataset.setClassIndex(dataset.numAttributes() - 1);
        numClasses = dataset.classAttribute().numValues();
        
        availableInstances = new int[dataset.numInstances()];
        for (int i = 0; i < availableInstances.length; i++) {
            availableInstances[i] = i;
        }// of for i
        
        availableAttributes = new int[dataset.numAttributes() - 1];
        for (int i = 0; i < availableAttributes.length; i++) {
            availableAttributes[i] = i;
        }// of for i
        
        //Initialize.
        children = null;
        
        //Determine the label by simple voting.
        label = getMajorityClass(availableInstances);

        // Determine whether or not it is pure.
        pure = pureJudge(availableInstances);
    }// of the first constructor


    /**
     * The second constructor. Build a child node from a block of the data.
     * @param paraDataset The given dataset.
     * @param paraAvailableInstances The available instances.
     * @param paraAvailableAttributes The available attributes.
     */
    public ID3(Instances paraDataset, int[] paraAvailableInstances, int[] paraAvailableAttributes) {
        // Copy the references instead of cloning the arrays.

        dataset = paraDataset;
        availableInstances = paraAvailableInstances;
        availableAttributes = paraAvailableAttributes;

        // Initialize.
        children = null;
        // Determine the label by simple voting.
        label = getMajorityClass(availableInstances);
        // Determine whether or not it is pure.
        pure = pureJudge(availableInstances);
    }// of the second constructor

    /**
     * Compute the majority class of the given block for voting.
     * @param paraBlock The block.
     * @return The majority class.
     */
    private int getMajorityClass(int[] paraBlock) {
        int[] tempClassCounts = new int[dataset.numClasses()];
        for (int i = 0; i < paraBlock.length; i++) {
            tempClassCounts[(int)dataset.instance(paraBlock[i]).classValue()]++;
        }// of for i

        int resultMajorityClass = -1;
        int tempMaxCount = -1;
        for (int i = 0; i < tempClassCounts.length; i++) {
            if (tempMaxCount < tempClassCounts[i]) {
                tempMaxCount = tempClassCounts[i];
                resultMajorityClass = i;
            }// of if
        }// of for i
        return resultMajorityClass;
    }// of getMajorityClass

    /**
     * Is the given block pure?
     * @param paraBlock the block.
     * @return True if pure.
     */
    private boolean pureJudge(int[] paraBlock) {
        pure = true;

        for (int i = 0; i < paraBlock.length; i++) {
            if (dataset.instance(paraBlock[i]).classValue() != dataset.instance(paraBlock[0]).classValue()) {
                pure = false;
                break;
            }// of if
        }// of for i

        return pure;
    }// of pureJudge

    /**
     * Select the best attribute, i.e., the one with the minimal conditional entropy.
     * @return The best attribute index.
     */
    public int selectBestAttribute() {
        splitAttribute = -1;
        double tempMinimalEntropy = Double.MAX_VALUE;
        double tempEntropy;
        for (int i = 0; i < availableAttributes.length; i++) {
            tempEntropy = conditionalEntropy(availableAttributes[i]);
            if (tempMinimalEntropy > tempEntropy) {
                tempMinimalEntropy = tempEntropy;
                splitAttribute = availableAttributes[i];
            }// of if
        }// of for i
        return splitAttribute;
    }//of selectBestAttribute

    /**
     * Compute the conditional entropy of the class given an attribute, on the
     * available instances.
     * @param paraAttribute The given attribute.
     * @return The conditional entropy (in nats, since Math.log is the natural logarithm).
     */
    private double conditionalEntropy(int paraAttribute) {
        // Step 1. Count the joint value/class distribution.
        int tempNumClasses = dataset.numClasses();
        int tempNumValues = dataset.attribute(paraAttribute).numValues();
        int tempNumInstances = availableInstances.length;
        double[] tempValueCounts = new double[tempNumValues];
        double[][] tempCountMatrix = new double[tempNumValues][tempNumClasses];

        int tempClass, tempValue;
        for (int i = 0; i < tempNumInstances; i++) {
            tempClass = (int)dataset.instance(availableInstances[i]).classValue();
            tempValue = (int)dataset.instance(availableInstances[i]).value(paraAttribute);
            tempValueCounts[tempValue]++;
            tempCountMatrix[tempValue][tempClass]++;
        }// of for i

        // Step 2. Accumulate the weighted entropy of each value block.
        double resultEntropy = 0;
        double tempEntropy, tempFraction;
        for (int i = 0; i < tempNumValues; i++) {
            if(tempValueCounts[i] == 0) {
                continue;
            }// of if
            tempEntropy = 0;
            for (int j = 0; j < tempNumClasses; j++) {
                tempFraction = tempCountMatrix[i][j] / tempValueCounts[i];
                if(tempFraction == 0) {
                    continue;
                }
                tempEntropy += -tempFraction * Math.log(tempFraction);
            }// of for j

            resultEntropy += tempValueCounts[i] / tempNumInstances * tempEntropy;
        }// of for i

        return resultEntropy;
    }// of conditionalEntropy

    /**
     * Split the available instances according to the given attribute.
     * @param paraAttribute The given attribute.
     * @return The blocks, one per attribute value.
     */
    public int[][] splitData(int paraAttribute) {
        int tempNumValues = dataset.attribute(paraAttribute).numValues();
        int[][] resultBlocks = new int[tempNumValues][];
        int[] tempSizes = new int[tempNumValues];

        //First scan to count the size of each block.
        int tempValue;
        for (int i = 0; i < availableInstances.length; i++) {
            tempValue = (int)dataset.instance(availableInstances[i]).value(paraAttribute);
            tempSizes[tempValue]++;
        }// of for i

        //Allocate space
        for (int i = 0; i < tempNumValues; i++) {
            resultBlocks[i] = new int[tempSizes[i]];
        }// of for i

        //Second scan to fill.
        Arrays.fill(tempSizes, 0);
        for (int i = 0; i < availableInstances.length; i++) {
            tempValue = (int)dataset.instance(availableInstances[i]).value(paraAttribute);
            //Copy data.
            resultBlocks[tempValue][tempSizes[tempValue]] = availableInstances[i];
            tempSizes[tempValue]++;
        }// of for i
        return resultBlocks;
    }// of splitData

    /**
     * Build the tree recursively.
     */
    public void buildTree() {
        if (pureJudge(availableInstances)) {
            return;
        }// of if
        if (availableInstances.length <= smallBlockThreshold) {
            return;
        }// of if

        selectBestAttribute();
        int[][] tempSubBlocks = splitData(splitAttribute);
        children = new ID3[tempSubBlocks.length];

        // Construct the remaining attribute set. Since availableAttributes is
        // sorted, entries before splitAttribute keep their positions and entries
        // after it shift left by one.
        int[] tempRemainingAttributes = new int[availableAttributes.length - 1];
        for (int i = 0; i < availableAttributes.length; i++) {
            if (availableAttributes[i] < splitAttribute) {
                tempRemainingAttributes[i] = availableAttributes[i];
            } else if (availableAttributes[i] > splitAttribute) {
                tempRemainingAttributes[i - 1] = availableAttributes[i];
            }// of if
        }// of for i

        // Construct children.
        for (int i = 0; i < children.length; i++) {
            if ((tempSubBlocks[i] == null) || (tempSubBlocks[i].length == 0)) {
                children[i] = null;
            } else {
                children[i] = new ID3(dataset, tempSubBlocks[i], tempRemainingAttributes);

                // Important code: do this recursively
                children[i].buildTree();
            }// of if
        } // of for i
    }//of buildTree

    /**
     * Classify an instance by walking down the tree.
     * @param paraInstance The given instance.
     * @return The predicted label.
     */
    public int classify(Instance paraInstance) {
        if (children == null) {
            return label;
        }// of if

        ID3 tempChild = children[(int) paraInstance.value(splitAttribute)];
        if (tempChild == null) {
            return label;
        }// of if
        return tempChild.classify(paraInstance);
    }// of classify

    /**
     * Test on a testing set.
     * @param paraDataset The given testing data.
     * @return The accuracy.
     */
    public double test(Instances paraDataset) {
        double tempCorrect = 0;
        for (int i = 0; i < paraDataset.numInstances(); i++) {
            if (classify(paraDataset.instance(i)) == (int)paraDataset.instance(i).classValue()) {
                tempCorrect ++;
            }// of if
        }// of for i

        return tempCorrect / paraDataset.numInstances();
    }// of test

    /**
     * Test on the training set.
     * @return The accuracy.
     */
    public double selfTest() {
        return test(dataset);
    }// of selfTest


    /**
     * Convert the tree to a string for display.
     */
    public String toString() {
        String resultString = "";
        if (children == null) {
            resultString += "class = " + label;
        } else {
            String tempAttributeName = dataset.attribute(splitAttribute).name();
            for (int i = 0; i < children.length; i++) {
                if (children[i] == null) {
                    resultString += tempAttributeName + " = "
                            + dataset.attribute(splitAttribute).value(i) + ":" + "class = " + label
                            + "\r\n";
                } else {
                    resultString += tempAttributeName + " = "
                            + dataset.attribute(splitAttribute).value(i) + ":" + children[i]
                            + "\r\n";
                } // Of if
            }// of for i
        }// of if
        return resultString;
    }// of toString

    /**
     * A unit test of the ID3 class on the weather dataset.
     */
    public static void id3Test() {
        ID3 tempID3 = new ID3("D:\\研究生学习\\测试文件\\sampledata-main\\weather.arff");

        ID3.smallBlockThreshold = 3;
        tempID3.buildTree();
        System.out.println("The tree is: \r\n" + tempID3);

        double tempAccuracy = tempID3.selfTest();
        System.out.println("The accuracy is: " + tempAccuracy);
    }// of id3Test


    /**
     * The entrance of the program.
     * @param args Not used.
     */
    public static void main(String[] args) {
        id3Test();
    }// of main
}// of class ID3

Run result (class = 0 corresponds to N and class = 1 to P, following the value order of the @attribute Play declaration):

The tree is: 
Outlook = Sunny:Humidity = High:class = 0
Humidity = Normal:class = 1
Humidity = Low:class = 0

Outlook = Overcast:class = 1
Outlook = Rain:Windy = FALSE:class = 1
Windy = TRUE:class = 0


The accuracy is: 1.0
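
Note that selfTest evaluates on the training data, which is why the accuracy is 1.0. A held-out test can reuse the existing test(Instances) method. Below is a minimal sketch under the assumption of a separate file weather-test.arff with the same @attribute declarations as the training data; the file names and the wrapper class are hypothetical, not part of the original code.

package machinelearning.decisiontree;

import weka.core.Instances;

import java.io.FileReader;

public class ID3HoldoutTest {
    public static void main(String[] args) throws Exception {
        // Build the tree on the training data.
        ID3 tempTree = new ID3("weather.arff");
        tempTree.buildTree();

        // Load the held-out data and mark the last attribute as the class.
        FileReader tempReader = new FileReader("weather-test.arff");
        Instances tempTestSet = new Instances(tempReader);
        tempReader.close();
        tempTestSet.setClassIndex(tempTestSet.numAttributes() - 1);

        // Reuse the accuracy computation of ID3.test.
        System.out.println("Holdout accuracy: " + tempTree.test(tempTestSet));
    }// of main
}// of class ID3HoldoutTest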
The decision tree algorithm is widely used for both classification and regression. It classifies or predicts samples through a sequence of tests organized as a tree: each inner node tests an attribute or feature, each branch corresponds to one value of that attribute, and each leaf node holds a classification or prediction result.

Training a decision tree mainly involves the following steps:

  1. Feature selection: choose the best feature as the split attribute of the current node according to some criterion (such as information gain or the Gini index).
  2. Tree generation: partition the dataset into subsets according to the selected feature, and generate the subtrees recursively.
  3. Pruning: prune the tree to improve its generalization performance.

Advantages of decision trees include being easy to understand and interpret, having low computational complexity, and being insensitive to missing values. Disadvantages include a tendency to overfit and sensitivity to how the data are discretized.

As a case study: suppose we want to predict whether a person's income is above or below 50K from age, gender, education, and occupation. We first encode these features as numeric or nominal values, then train a decision tree on the data, and finally use the resulting model to classify new samples, e.g., predicting the income level of a new person from those four features.

The Gini index mentioned above is an alternative split criterion; a sketch follows.
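
A minimal sketch of the Gini impurity (used by CART; the ID3 code above uses conditional entropy instead). The class GiniSketch and its count vector are illustrative, not part of the original code; statistics like tempCountMatrix from conditionalEntropy above could feed it directly.

public class GiniSketch {
    /**
     * Gini impurity of one class-count vector: 1 minus the sum of squared
     * class fractions. Lower is purer; 0 means a single class.
     */
    public static double gini(double[] paraClassCounts) {
        double tempTotal = 0;
        for (double tempCount : paraClassCounts) {
            tempTotal += tempCount;
        }// of for
        if (tempTotal == 0) {
            return 0;
        }// of if

        double tempSumOfSquares = 0;
        for (double tempCount : paraClassCounts) {
            double tempFraction = tempCount / tempTotal;
            tempSumOfSquares += tempFraction * tempFraction;
        }// of for

        return 1 - tempSumOfSquares;
    }// of gini

    public static void main(String[] args) {
        // Outlook = Sunny on the weather data covers 3 N and 2 P instances.
        System.out.println("Gini of {3, 2}: " + gini(new double[] {3, 2})); // 0.48
    }// of main
}// of class GiniSketch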
