使用 CART 算法进行分类

        这里分类的数据是鸢尾花(Iris)数据集中的第 0 类、第 1 类和第 2 类。这里使用的是 CART 算法,通过计算基尼指数(Gini index)来选择划分点并进行分类。每一类均用了十个样本来测试,剩余的数据全部用作训练集。测试出来的效果还不错,所以并没有进行剪枝操作。至于理论部分,网上有许多优秀的博客可以参考学习。

import numpy as np
import pandas as pd
from typing import List, Self

class Node:
    def __init__(self,
                 feature_index: int | None = None,
                 threshold: float | None = None,
                 left: Self | None = None,
                 right: Self | None = None,
                 value: int | None = None
    ):
        self.feature_index = feature_index  # 用于分割数据的特征索引
        self.threshold = threshold  # 分割阈值
        self.left = left  # 左子树
        self.right = right  # 右子树
        self.value = value  # 叶子节点的预测值

def get_gini(x: np.ndarray[float, float], y: np.ndarray[int]) -> float:
    unique_labels = np.unique(y)
    num_classes = len(unique_labels)
    rows = len(x)

    sorted_data = sorted(zip(x, y), key=lambda pair: pair[0])
    thresholds: list[float] = [(sorted_data[i][0] + sorted_data[i+1][0]) / 2 for i in range(rows-1)]
    gini_values: list[float] = []
    for threshold in thresholds:
        count_left = np.zeros(num_classes)
        count_right = np.zeros(num_classes)

        for i in range(rows):
            if x[i] < threshold:
                index = np.where(unique_labels == y[i])[0][0]
                count_left[index] += 1
            else:
                index = np.where(unique_labels == y[i])[0][0]
                count_right[index] += 1

        p_left = 0 if (left_count := np.sum(count_left)) == 0 else count_left / left_count
        p_right = 0 if (right_count := np.sum(count_right)) == 0 else count_right / right_count
        gini_left = 1 - np.sum(p_left ** 2)
        gini_right = 1 - np.sum(p_right ** 2)
        gini: float = (np.sum(count_left) / rows) * gini_left + (np.sum(count_right) / rows) * gini_right
        gini_values.append(gini)
    min_gini_index = np.argmin(gini_values)
    corresponding_threshold = thresholds[min_gini_index]
    return corresponding_threshold

def build_tree(node: Node, x: np.ndarray[float, float], y: np.ndarray[int], depth: int, max_depth: int):
    if depth >= max_depth - 1 or len(np.unique(y)) == 1:
        node.value = np.bincount(y).argmax()  # 叶子节点的预测值为平均值
        return

    best_feature = 0
    best_threshold = 0
    best_gini = float('inf')
    for i in range(x.shape[1]):
        threshold = get_gini(x[:, i], y)  # 使用基尼系数选择最佳分割点
        if threshold < best_gini:
            best_gini = threshold
            best_feature = i
            best_threshold = threshold

    left_indices: np.ndarray[bool] = x[:, best_feature] < best_threshold
    right_indices = ~left_indices

    node.feature_index = best_feature
    node.threshold = best_threshold

    if left_indices.any():
        node.left = Node()
        build_tree(node.left, x[left_indices, :], y[left_indices], depth+1, max_depth)
    if right_indices.any():
        node.right = Node()
        build_tree(node.right, x[right_indices, :], y[right_indices], depth+1, max_depth)

def predict_tree(node: "Node", x: np.ndarray) -> int:
    """Predict the class of one sample by walking the tree from ``node``.

    At each internal node the sample goes left when
    ``x[node.feature_index] < node.threshold`` and right otherwise, until a
    node with a stored ``value`` is reached.

    Args:
        node: Root of the (sub)tree to evaluate.
        x: 1-D feature vector of a single sample.

    Returns:
        The ``value`` stored on the leaf that the sample reaches.

    Raises:
        ValueError: If a node has neither a prediction nor a usable split.
    """
    # Iterative descent: no recursion depth limit to worry about.
    while node.value is None:
        if node.feature_index is None or node.threshold is None:
            raise ValueError("decision tree node has neither a prediction nor a split")

        if x[node.feature_index] < node.threshold:
            preferred, fallback = node.left, node.right
        else:
            preferred, fallback = node.right, node.left

        # A tree may contain a split with only one child (all training
        # samples fell on one side). Use that child rather than failing on
        # a missing branch.
        child = preferred if preferred is not None else fallback
        if child is None:
            raise ValueError("decision tree node has a split but no children")
        node = child
    return node.value


# Fit the decision tree on the training spreadsheet.
data = pd.read_excel(
    "C:/Users/wxc/PycharmProjects/pythonProject4/机器学习/回归/决策树/train.xlsx"
)
x_train = data.iloc[:, 1:5].to_numpy()  # columns 1-4 are used as features
y_train = data.iloc[:, 6].to_numpy()  # column 6 is used as the class label

tree_root = Node()
build_tree(tree_root, x_train, y_train, depth=0, max_depth=4)  # grow the tree recursively

# Evaluate the fitted tree on the test spreadsheet (read without a header row).
data1 = pd.read_excel(
    "C:/Users/wxc/PycharmProjects/pythonProject4/机器学习/回归/决策树/test.xlsx",
    header=None,
)
x_test = data1.iloc[:, 1:5].to_numpy()
y_test = data1.iloc[:, 6].to_numpy()
prediction = [predict_tree(tree_root, sample) for sample in x_test]
print("预测值为:", prediction)

# Accuracy is the share of test samples whose prediction equals the true label.
c = sum(1 for actual, predicted in zip(y_test, prediction) if actual == predicted)
print('准确率')
print(c/(len(y_test)))

  • 7
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
以下是一个简单的使用 CART 决策树算法的 Java 代码示例: ```java import java.util.ArrayList; import java.util.HashMap; public class CART { public static void main(String[] args) { // 构造训练集数据 ArrayList<HashMap<String, String>> trainData = new ArrayList<>(); HashMap<String, String> data1 = new HashMap<>(); data1.put("age", "青年"); data1.put("income", "高"); data1.put("student", "否"); data1.put("credit_rating", "一般"); data1.put("class", "不放贷"); trainData.add(data1); HashMap<String, String> data2 = new HashMap<>(); data2.put("age", "青年"); data2.put("income", "高"); data2.put("student", "否"); data2.put("credit_rating", "好"); data2.put("class", "不放贷"); trainData.add(data2); HashMap<String, String> data3 = new HashMap<>(); data3.put("age", "中年"); data3.put("income", "高"); data3.put("student", "否"); data3.put("credit_rating", "好"); data3.put("class", "放贷"); trainData.add(data3); HashMap<String, String> data4 = new HashMap<>(); data4.put("age", "中年"); data4.put("income", "中等"); data4.put("student", "否"); data4.put("credit_rating", "好"); data4.put("class", "放贷"); trainData.add(data4); HashMap<String, String> data5 = new HashMap<>(); data5.put("age", "中年"); data5.put("income", "中等"); data5.put("student", "是"); data5.put("credit_rating", "一般"); data5.put("class", "放贷"); trainData.add(data5); HashMap<String, String> data6 = new HashMap<>(); data6.put("age", "老年"); data6.put("income", "中等"); data6.put("student", "是"); data6.put("credit_rating", "好"); data6.put("class", "放贷"); trainData.add(data6); HashMap<String, String> data7 = new HashMap<>(); data7.put("age", "老年"); data7.put("income", "低"); data7.put("student", "是"); data7.put("credit_rating", "好"); data7.put("class", "不放贷"); trainData.add(data7); HashMap<String, String> data8 = new HashMap<>(); data8.put("age", "老年"); data8.put("income", "低"); data8.put("student", "否"); data8.put("credit_rating", "一般"); data8.put("class", "不放贷"); trainData.add(data8); // 训练决策树模型 DecisionTreeModel model = train(trainData); System.out.println("决策树模型:" + model); // 
预测新数据 HashMap<String, String> newData = new HashMap<>(); newData.put("age", "青年"); newData.put("income", "中等"); newData.put("student", "否"); newData.put("credit_rating", "一般"); String result = predict(newData, model); System.out.println("新数据预测结果:" + result); } /** * 训练决策树模型 * @param trainData 训练集数据 * @return 决策树模型 */ public static DecisionTreeModel train(ArrayList<HashMap<String, String>> trainData) { // 获取训练集属性列表 ArrayList<String> attributeList = new ArrayList<>(); for (String key : trainData.get(0).keySet()) { attributeList.add(key); } // 构建决策树模型 DecisionTreeModel model = new DecisionTreeModel(); buildDecisionTree(trainData, attributeList, model); return model; } /** * 构建决策树 * @param trainData 训练集数据 * @param attributeList 属性列表 * @param model 决策树模型 */ public static void buildDecisionTree(ArrayList<HashMap<String, String>> trainData, ArrayList<String> attributeList, DecisionTreeModel model) { // 如果训练集中所有数据属于同一类别,则将当前节点设置为叶子节点,并返回 boolean isSameClass = true; String firstClass = trainData.get(0).get("class"); for (HashMap<String, String> data : trainData) { if (!data.get("class").equals(firstClass)) { isSameClass = false; break; } } if (isSameClass) { model.isLeaf = true; model.className = firstClass; return; } // 如果属性列表为空,则将当前节点设置为叶子节点,并将其类别设置为训练集中最常见的类别 if (attributeList.isEmpty()) { model.isLeaf = true; model.className = getMostCommonClass(trainData); return; } // 选择最佳属性(即使得信息增益最大的属性) String bestAttribute = getBestAttribute(trainData, attributeList); model.attributeName = bestAttribute; // 根据最佳属性分裂训练集 ArrayList<ArrayList<HashMap<String, String>>> splitDataList = splitData(trainData, bestAttribute); // 递归构建子树 ArrayList<String> newAttributeList = new ArrayList<>(attributeList); newAttributeList.remove(bestAttribute); // 在属性列表中删除已经使用的属性 for (ArrayList<HashMap<String, String>> splitData : splitDataList) { DecisionTreeModel subModel = new DecisionTreeModel(); model.subModelList.add(subModel); buildDecisionTree(splitData, newAttributeList, subModel); } } /** * 预测新数据 * 
@param newData 新数据 * @param model 决策树模型 * @return 预测结果 */ public static String predict(HashMap<String, String> newData, DecisionTreeModel model) { // 如果当前节点是叶子节点,则返回其类别 if (model.isLeaf) { return model.className; } // 根据当前节点的属性进行分裂 String attributeValue = newData.get(model.attributeName); for (DecisionTreeModel subModel : model.subModelList) { if (subModel.attributeValue.equals(attributeValue)) { return predict(newData, subModel); } } // 如果当前节点没有与新数据匹配的子节点,则将其类别设置为训练集中最常见的类别 return getMostCommonClass(model.trainData); } /** * 获取训练集中最常见的类别 * @param trainData 训练集数据 * @return 最常见的类别 */ public static String getMostCommonClass(ArrayList<HashMap<String, String>> trainData) { HashMap<String, Integer> classCountMap = new HashMap<>(); for (HashMap<String, String> data : trainData) { String className = data.get("class"); if (classCountMap.containsKey(className)) { classCountMap.put(className, classCountMap.get(className) + 1); } else { classCountMap.put(className, 1); } } String mostCommonClass = ""; int maxCount = -1; for (String className : classCountMap.keySet()) { int count = classCountMap.get(className); if (count > maxCount) { mostCommonClass = className; maxCount = count; } } return mostCommonClass; } /** * 获取训练集中最佳属性 * @param trainData 训练集数据 * @param attributeList 属性列表 * @return 最佳属性 */ public static String getBestAttribute(ArrayList<HashMap<String, String>> trainData, ArrayList<String> attributeList) { String bestAttribute = ""; double maxInformationGain = -1; for (String attribute : attributeList) { double informationGain = calculateInformationGain(trainData, attribute); if (informationGain > maxInformationGain) { bestAttribute = attribute; maxInformationGain = informationGain; } } return bestAttribute; } /** * 根据指定属性值分裂训练集 * @param trainData 训练集数据 * @param attributeName 属性名称 * @return 分裂后的数据集列表 */ public static ArrayList<ArrayList<HashMap<String, String>>> splitData(ArrayList<HashMap<String, String>> trainData, String attributeName) { 
ArrayList<ArrayList<HashMap<String, String>>> splitDataList = new ArrayList<>(); for (HashMap<String, String> data : trainData) { String attributeValue = data.get(attributeName); boolean isSplitDataExist = false; for (ArrayList<HashMap<String, String>> splitData : splitDataList) { if (splitData.get(0).get(attributeName).equals(attributeValue)) { splitData.add(data); isSplitDataExist = true; break; } } if (!isSplitDataExist) { ArrayList<HashMap<String, String>> newSplitData = new ArrayList<>(); newSplitData.add(data); splitDataList.add(newSplitData); } } for (ArrayList<HashMap<String, String>> splitData : splitDataList) { if (splitData.size() > 0) { String attributeValue = splitData.get(0).get(attributeName); DecisionTreeModel subModel = new DecisionTreeModel(); subModel.attributeName = attributeName; subModel.attributeValue = attributeValue; subModel.trainData = splitData; } } return splitDataList; } /** * 计算指定属性的信息增益 * @param trainData 训练集数据 * @param attributeName 属性名称 * @return 信息增益 */ public static double calculateInformationGain(ArrayList<HashMap<String, String>> trainData, String attributeName) { // 计算训练集的熵 double entropy = calculateEntropy(trainData); // 计算分裂后的熵 double splitEntropy = 0; ArrayList<ArrayList<HashMap<String, String>>> splitDataList = splitData(trainData, attributeName); for (ArrayList<HashMap<String, String>> splitData : splitDataList) { double splitDataEntropy = calculateEntropy(splitData); double splitDataProbability = (double) splitData.size() / trainData.size(); splitEntropy += splitDataEntropy * splitDataProbability; } // 计算信息增益 double informationGain = entropy - splitEntropy; return informationGain; } /** * 计算数据集的熵 * @param dataList 数据集 * @return 熵 */ public static double calculateEntropy(ArrayList<HashMap<String, String>> dataList) { HashMap<String, Integer> classCountMap = new HashMap<>(); for (HashMap<String, String> data : dataList) { String className = data.get("class"); if (classCountMap.containsKey(className)) { 
classCountMap.put(className, classCountMap.get(className) + 1); } else { classCountMap.put(className, 1); } } double entropy = 0; for (String className : classCountMap.keySet()) { double probability = (double) classCountMap.get(className) / dataList.size(); entropy -= probability * Math.log(probability) / Math.log(2); } return entropy; } } /** * 决策树模型 */ class DecisionTreeModel { public boolean isLeaf; // 是否是叶子节点 public String attributeName; // 分裂属性名称 public String attributeValue; // 分裂属性值 public ArrayList<DecisionTreeModel> subModelList; // 子模型列表 public String className; // 类别名称 public ArrayList<HashMap<String, String>> trainData; // 训练集数据 public DecisionTreeModel() { this.isLeaf = false; this.attributeName = ""; this.attributeValue = ""; this.subModelList = new ArrayList<>(); this.className = ""; this.trainData = new ArrayList<>(); } public String toString() { StringBuilder sb = new StringBuilder(); if (isLeaf) { sb.append(className); } else { sb.append(attributeName + " -> "); for (DecisionTreeModel subModel : subModelList) { sb.append(subModel.attributeValue + ": " + subModel.toString() + "; "); } } return sb.toString(); } } ``` 这个示例代码实现了一个简化的 CART 决策树算法,并提供了训练和预测的方法。由于数据集比较小,所以没有进行剪枝等优化操作。在实际应用中,可以根据具体情况进行改进。
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值