Machine Learning Series, Decision Trees (Part 1): A Brief Introduction to Decision Trees

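A decision tree is a widely used method for both classification and regression, and it handles classification problems with multiple features. Below is an example of a decision tree classifier implemented with C++ classes.

First, we define a node class to represent a node of the decision tree. The headers listed here cover the whole program:

```
#include <algorithm>
#include <cmath>
#include <iostream>
#include <unordered_map>
#include <vector>
using namespace std;

// Internal nodes store a split (feature, threshold); leaves store a class label.
class Node {
public:
    int feature;      // index of the feature used for the split
    double threshold; // go left if x[feature] <= threshold, otherwise right
    int label;        // class label (-1 if none)
    Node* left;       // left subtree
    Node* right;      // right subtree
    Node() : feature(0), threshold(0.0), label(-1), left(nullptr), right(nullptr) {}
};
```

Next, we define a decision tree class that builds the tree:

```
class DecisionTree {
public:
    DecisionTree();
    ~DecisionTree();
    // The tree owns raw pointers, so copying is disabled to avoid double deletes.
    DecisionTree(const DecisionTree&) = delete;
    DecisionTree& operator=(const DecisionTree&) = delete;
    void buildTree(const vector<vector<double>>& data, const vector<int>& labels);
    int predict(const vector<double>& data) const;
private:
    Node* root;
    void destroy(Node* node);
    int getMajorityLabel(const vector<int>& labels) const;
    double entropy(const vector<int>& labels) const;
    int getBestFeature(const vector<vector<double>>& data, const vector<int>& labels, double& bestThreshold) const;
    Node* buildSubTree(const vector<vector<double>>& data, const vector<int>& labels);
};
```

Here, buildTree() builds the decision tree and predict() makes predictions. The implementation of buildTree() follows:

```
void DecisionTree::buildTree(const vector<vector<double>>& data, const vector<int>& labels) {
    destroy(root);  // free any previously built tree
    root = buildSubTree(data, labels);
}

Node* DecisionTree::buildSubTree(const vector<vector<double>>& data, const vector<int>& labels) {
    Node* node = new Node;
    if (labels.empty()) {
        return node;  // no samples reach this node: label stays -1
    }
    node->label = getMajorityLabel(labels);  // every node records its majority label
    // Pure node: all samples share one label, so it becomes a leaf.
    if (all_of(labels.begin(), labels.end(), [&](int l) { return l == labels[0]; })) {
        return node;
    }
    double threshold = 0.0;
    int bestFeature = getBestFeature(data, labels, threshold);
    if (bestFeature == -1) {
        return node;  // no split reduces entropy: majority-label leaf
    }
    vector<vector<double>> leftData, rightData;
    vector<int> leftLabels, rightLabels;
    for (size_t i = 0; i < data.size(); i++) {
        if (data[i][bestFeature] <= threshold) {
            leftData.push_back(data[i]);
            leftLabels.push_back(labels[i]);
        } else {
            rightData.push_back(data[i]);
            rightLabels.push_back(labels[i]);
        }
    }
    if (leftData.empty() || rightData.empty()) {
        return node;  // degenerate split: majority-label leaf
    }
    node->feature = bestFeature;
    node->threshold = threshold;
    node->left = buildSubTree(leftData, leftLabels);
    node->right = buildSubTree(rightData, rightLabels);
    return node;
}
```

In buildSubTree(), we first check whether the label set is empty; if so, we return a node without a label (-1). Otherwise, we compute the most frequent label and store it as the node's label. If all samples share the same label, the node is pure and becomes a leaf. Next, we select the best feature and threshold to split the data. If no split reduces the entropy, or the split would leave one side empty, the node becomes a leaf carrying the majority label. Otherwise, we divide the data into a left subset and a right subset and build the two subtrees recursively.

The implementation of predict() follows:

```
int DecisionTree::predict(const vector<double>& data) const {
    if (root == nullptr) {
        return -1;  // tree has not been built
    }
    const Node* node = root;
    // Walk down until reaching a leaf (a node without children).
    while (node->left != nullptr && node->right != nullptr) {
        node = (data[node->feature] <= node->threshold) ? node->left : node->right;
    }
    return node->label;
}
```

In predict(), we start at the root and walk down the tree, going to the left subtree when the sample's feature value is at most the node's threshold and to the right subtree otherwise, until we reach a leaf. The leaf's label is the prediction.

Finally, we need several helper functions: finding the most frequent label, computing the entropy of a label set, and choosing the best feature and threshold. getBestFeature() tries the midpoint between every pair of consecutive distinct values of every feature and scores each split by its information gain, that is, the entropy of the parent node minus the size-weighted entropies of the two children. The helpers can be implemented as follows:

```
// Most frequent label (-1 if empty); ties are broken arbitrarily.
int DecisionTree::getMajorityLabel(const vector<int>& labels) const {
    if (labels.empty()) {
        return -1;
    }
    unordered_map<int, int> labelCounts;
    for (int l : labels) {
        labelCounts[l]++;  // operator[] starts missing counts at 0
    }
    int majorityLabel = -1;
    int maxCount = -1;
    for (const auto& kv : labelCounts) {
        if (kv.second > maxCount) {
            maxCount = kv.second;
            majorityLabel = kv.first;
        }
    }
    return majorityLabel;
}

// Shannon entropy of the label distribution in bits: H = -sum_c p_c * log2(p_c).
double DecisionTree::entropy(const vector<int>& labels) const {
    if (labels.empty()) {
        return 0.0;
    }
    unordered_map<int, int> counts;
    for (int l : labels) {
        counts[l]++;
    }
    double h = 0.0;
    for (const auto& kv : counts) {
        double p = (double)kv.second / labels.size();
        h -= p * log2(p);
    }
    return h;
}

// Returns the feature with the largest information gain and writes its threshold to
// bestThreshold, or returns -1 if no split reduces entropy.
// Assumes labels is non-empty and data.size() == labels.size().
int DecisionTree::getBestFeature(const vector<vector<double>>& data, const vector<int>& labels, double& bestThreshold) const {
    size_t n = labels.size();
    size_t numFeatures = data[0].size();
    double parentEntropy = entropy(labels);
    double maxGain = 1e-12;  // a split must beat this tiny gain (guards against rounding noise)
    int bestFeature = -1;
    vector<double> featureValues(n);
    for (size_t i = 0; i < numFeatures; i++) {
        for (size_t j = 0; j < n; j++) {
            featureValues[j] = data[j][i];
        }
        sort(featureValues.begin(), featureValues.end());
        // Candidate thresholds: midpoints between consecutive distinct values.
        for (size_t j = 0; j + 1 < n; j++) {
            if (featureValues[j] == featureValues[j + 1]) {
                continue;
            }
            double threshold = (featureValues[j] + featureValues[j + 1]) / 2.0;
            vector<int> leftLabels;
            vector<int> rightLabels;
            for (size_t k = 0; k < n; k++) {
                if (data[k][i] <= threshold) {
                    leftLabels.push_back(labels[k]);
                } else {
                    rightLabels.push_back(labels[k]);
                }
            }
            // Information gain = H(parent) - size-weighted entropies of the children.
            double gain = parentEntropy
                        - (double)leftLabels.size() / n * entropy(leftLabels)
                        - (double)rightLabels.size() / n * entropy(rightLabels);
            if (gain > maxGain) {
                maxGain = gain;
                bestFeature = (int)i;
                bestThreshold = threshold;
            }
        }
    }
    return bestFeature;
}

void DecisionTree::destroy(Node* node) {
    if (node == nullptr) {
        return;
    }
    destroy(node->left);
    destroy(node->right);
    delete node;
}

DecisionTree::DecisionTree() : root(nullptr) {}

DecisionTree::~DecisionTree() { destroy(root); }
```

Now we can build and use a decision tree with the classes defined above. Here is a simple example:

```
int main() {
    // Training data: four samples with two features each
    vector<vector<double>> data = {{1.0, 2.0}, {2.0, 1.0}, {3.0, 4.0}, {4.0, 3.0}};
    // Training labels
    vector<int> labels = {0, 0, 1, 1};
    // Build the decision tree
    DecisionTree dt;
    dt.buildTree(data, labels);
    // Predict the label of a test sample
    vector<double> testData = {3.5, 2.5};
    int pred = dt.predict(testData);
    cout << "Prediction: " << pred << endl;
    // The tree is freed automatically by ~DecisionTree()
    return 0;
}
```

The output should be "Prediction: 1", meaning the test sample is assigned label 1.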
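The split criterion can be written out explicitly. Assuming the tree is meant to maximize information gain (the ID3-style criterion), let $S$ be the samples at a node, $p_c$ the fraction of them with label $c$, and $S_L = \{x \in S : x_j \le t\}$, $S_R = S \setminus S_L$ the two sides of a split on feature $j$ at threshold $t$:

$$
H(S) = -\sum_{c} p_c \log_2 p_c, \qquad
\mathrm{IG}(S; j, t) = H(S) - \frac{|S_L|}{|S|} H(S_L) - \frac{|S_R|}{|S|} H(S_R)
$$

For the training data in the example, the labels $\{0, 0, 1, 1\}$ give $H(S) = 1$ bit, and the candidate thresholds on feature 0 (values 1, 2, 3, 4) score:

$$
\begin{aligned}
t = 1.5:&\quad \mathrm{IG} = 1 - \tfrac{1}{4}\cdot 0 - \tfrac{3}{4} H(\{0,1,1\}) = 1 - \tfrac{3}{4}\left(\log_2 3 - \tfrac{2}{3}\right) \approx 0.311 \\
t = 2.5:&\quad \mathrm{IG} = 1 - \tfrac{2}{4}\cdot 0 - \tfrac{2}{4}\cdot 0 = 1 \\
t = 3.5:&\quad \mathrm{IG} \approx 0.311 \quad \text{(by symmetry)}
\end{aligned}
$$

So $t = 2.5$ on feature 0 separates the two classes perfectly. Feature 1 (values 2, 1, 4, 3) also reaches 1 bit at $t = 2.5$, but the search replaces its current best only on a strictly larger score, so feature 0, which is examined first, is kept. The test sample has $x_0 = 3.5 > 2.5$, so it goes to the right subtree, where every training sample has label 1, and the prediction is 1.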

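To check a single split by hand outside the class, here is a small standalone sketch of the information-gain criterion. `entropyBits` and `infoGain` are hypothetical helper names, not part of the DecisionTree class; the sketch assumes the ID3-style information-gain criterion and takes one feature column as `values`. It uses std::map so that labels are visited in a deterministic order.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <map>
#include <vector>

// Shannon entropy (in bits) of a list of class labels: H = -sum_c p_c * log2(p_c).
double entropyBits(const std::vector<int>& labels) {
    if (labels.empty()) return 0.0;
    std::map<int, int> counts;
    for (int l : labels) counts[l]++;
    double h = 0.0;
    for (const auto& kv : counts) {
        double p = static_cast<double>(kv.second) / labels.size();
        h -= p * std::log2(p);
    }
    return h;
}

// Information gain of splitting one feature column at `threshold`:
// samples with value <= threshold go left, the rest go right.
double infoGain(const std::vector<double>& values, const std::vector<int>& labels, double threshold) {
    assert(values.size() == labels.size());
    std::vector<int> left, right;
    for (std::size_t k = 0; k < values.size(); k++) {
        (values[k] <= threshold ? left : right).push_back(labels[k]);
    }
    double n = static_cast<double>(labels.size());
    return entropyBits(labels)
         - left.size() / n * entropyBits(left)
         - right.size() / n * entropyBits(right);
}
```

For instance, `infoGain({1.0, 2.0, 3.0, 4.0}, {0, 0, 1, 1}, 2.5)` evaluates feature 0 of the example training data at threshold 2.5 and returns 1.0, a perfect split. A getBestFeature-style search simply calls such a function for every candidate midpoint of every feature and keeps the maximum.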