Decision Tree Classifier - A Java Implementation

The decision tree model follows the same basic idea as an if-else structure: if a sample satisfies a given condition, it is assigned to the corresponding class, and the depth of the tree is analogous to the nesting depth of the if-else chain.
The central question for a decision tree is how to choose a suitable splitting criterion for data points with multi-dimensional features. Consider a chicken (two legs, wings, no webbed feet, ...) and a duck (two legs, wings, webbed feet, ...). Now a strange creature shows up (two legs, wings, no webbed feet, ...). Testing its legs or wings tells us nothing about which animal it is, but testing for webbed feet settles the question immediately. The point of the example is that a decision tree must pick the most discriminative feature at each step in order to separate the classes.
This article grows the tree according to the principle of minimizing the information entropy of the dataset. Without going into the full background of information entropy: the more ordered a system is, the lower its entropy, and the more disordered, the higher. It is computed as follows:
H(X) = E[I(x_i)] = E[\log_2(1/p(x_i))] = -\sum_{i=1}^{n} p(x_i)\log_2 p(x_i)
where p(x_i) is the probability (the frequency, in statistical terms) with which the value x_i occurs in the population x.
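As a quick sanity check of the formula (an illustrative computation of my own): a two-class dataset whose samples split evenly has p(x_1) = p(x_2) = 0.5, which gives
H(X) = -(0.5\log_2 0.5 + 0.5\log_2 0.5) = 1
i.e. one full bit of disorder, whereas a dataset containing only a single class has H(X) = 0. The tree-growing rule below always prefers the split that drives this quantity down the most.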
Before presenting the implementation, here is the tree structure used throughout.

import java.io.Serializable;
import java.util.Map;

/**
 * Created by Song on 2017/1/4.
 * Tree node; serializable for storage on disk.
 */
public class Node implements Serializable{

    //feature name tested at this node (a String in this application)
    public Object element;
    //maps each decision value (edge label) to the corresponding child node
    public Map<Object,Node> child;
}

This design reflects the specific application here: element holds the feature name as a String, and in the Map each key is a decision value (the label on the edge connecting two nodes) while the corresponding value is the next node.
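To make this layout concrete, here is a minimal sketch (my own illustration, not from the original article) that hand-builds the webbed-feet tree from the example above; the feature and class strings are invented for the demonstration:

import java.util.HashMap;

public class NodeDemo {
    public static void main(String[] args) {
        //the root tests the "webbed feet" feature
        Node root = new Node();
        root.element = "webbed feet";
        root.child = new HashMap<Object, Node>();

        //leaves reuse element to hold the class name and have no children
        Node duck = new Node();
        duck.element = "duck";
        Node chicken = new Node();
        chicken.element = "chicken";

        //each key is an edge label, i.e. a possible value of the feature
        root.child.put("yes", duck);
        root.child.put("no", chicken);
    }
}
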
Next is the Java object-serialization code (commented out during my tests). It stores the tree obtained from the training set in a file, so the decision tree does not have to be retrained on every run.

//build the tree from the training set
Node root = handler.createTree(dataSet,featurelabels,labelStr);
//store the tree structure
ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(new File("E:\\dectree.txt")));
oos.writeObject(root);
oos.flush();
oos.close();
//read the tree structure back
ObjectInputStream ois = new ObjectInputStream(new FileInputStream(new File("E:\\dectree.txt")));
Node tree = (Node) ois.readObject();
ois.close();
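
On Java 7 and later, a try-with-resources block is a safer variant of the same save/load logic, since it closes the streams automatically even if an exception is thrown; a minimal sketch (same file path and variables as above):

//store the tree
try (ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream("E:\\dectree.txt"))) {
    oos.writeObject(root);
}
//read the tree back
try (ObjectInputStream ois = new ObjectInputStream(new FileInputStream("E:\\dectree.txt"))) {
    Node tree = (Node) ois.readObject();
}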

Below is the implementation of the decision tree classifier itself:

import java.util.HashMap;
import Jama.Matrix; //assumed: Matrix is Jama's; any type with the same accessors works

/**
 * Created by Song on 2017/1/3.
 * Decision tree
 */
public class DectreeHandler {

    /**
     * Compute the Shannon entropy of a dataset
     * @param dataSet dataset (the last column holds the class label)
     * @return Shannon entropy
     */
    private static double calcShannonEnt(Matrix dataSet){
        int m = dataSet.getRowDimension();
        int n = dataSet.getColumnDimension();
        double currentLabel = 0;
        double shannonEnt = 0;
        double rate = 0;
        HashMap<Double,Integer> labelCounts = new HashMap<Double, Integer>();
        //count how many times each class label occurs
        for(int i=0;i<m;i++){
            currentLabel = dataSet.get(i,n-1);
            if(!labelCounts.containsKey(currentLabel))
                labelCounts.put(currentLabel,0);
            labelCounts.put(currentLabel,labelCounts.get(currentLabel)+1);
        }
        //accumulate the Shannon entropy of the whole dataset
        for(double key:labelCounts.keySet()){
            rate = labelCounts.get(key)/(double)m;
            shannonEnt -= rate*Math.log(rate)/Math.log(2);
        }
        return shannonEnt;
    }

    //...(the remaining methods of DectreeHandler are truncated in the source)
}
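
For orientation, here is a hypothetical way to exercise calcShannonEnt; because the method is private, the test entry point below would have to sit inside DectreeHandler itself, and the sample values are my own:

public static void main(String[] args) {
    //four samples: one feature column plus the label column;
    //the labels split 2/2, so the entropy comes out to exactly 1.0
    double[][] values = {
            {1.0, 0.0},
            {0.0, 0.0},
            {1.0, 1.0},
            {0.0, 1.0}
    };
    Matrix dataSet = new Matrix(values);
    System.out.println(calcShannonEnt(dataSet)); //prints 1.0
}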
