决策树是一种用于分类和回归的监督学习方法。决策树目标是创建一个模型,通过学习从数据特征推断出的简单决策规则来预测目标变量的值。
决策树优缺点
决策树的一些优点是:
易于理解和解释。树可以被可视化。
需要很少的训练数据
能处理数值和类别数据
能够处理多输出问题
决策树的一些缺点是:
深度太深,很容易过拟合
决策树可能不稳定
决策树的预测结果不是连续的
决策树节点分裂过程是贪心的
sklearn 决策树API
DecisionTreeClassifier
https://scikit-learn.org/stable/modules/generated/sklearn.tree.DecisionTreeClassifier.htm
二分类或多分类、多标签分类
from sklearn.datasets import load_iris
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeClassifier
clf = DecisionTreeClassifier(random_state=0)
iris = load_iris()
cross_val_score(clf, iris.data, iris.target, cv=10)
DecisionTreeRegressor
https://scikit-learn.org/stable/modules/generated/sklearn.tree.DecisionTreeRegressor.html
回归、多标签回归
from sklearn.datasets import load_diabetes
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeRegressor
X, y = load_diabetes(return_X_y=True)
regressor = DecisionTreeRegressor(random_state=0)
cross_val_score(regressor, X, y, cv=10)
sklearn 底层树结构
树结构
决策分类器有一个名为的属性tree_
,它允许访问低级属性,例如node_count
(节点总数),和max_depth
(树的最大深度),它还存储整个二叉树结构,表示为多个并行数组。
children_left[i]
: 节点的左子节点的 idi
,如果是叶节点则为 -1children_right[i]
: 节点右子节点的idi
,如果是叶节点则为-1feature[i]
: 用于分裂节点的特征i
threshold[i]
:节点的阈值i
n_node_samples[i]
:到达节点的训练样本数i
impurity[i]
:节点处的杂质i
iris = load_iris()
X = iris.data
y = iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
clf = DecisionTreeClassifier(max_leaf_nodes=3, random_state=0)
clf.fit(X_train, y_train)
n_nodes = clf.tree_.node_count
children_left = clf.tree_.children_left
children_right = clf.tree_.children_right
feature = clf.tree_.feature
threshold = clf.tree_.threshold
决策路径
decision_path
方法输出一个指示矩阵,允许检索感兴趣的样本遍历的节点。位置处的指示矩阵中的非零元素表示样本经过节点。
apply
方法返回样本所达到的叶id,得到样本所达到的叶子节点 ID 的数组,这可以用对样本进行编码,也可以用于特征工程。
iris = load_iris()
X = iris.data
y = iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
clf = DecisionTreeClassifier(max_leaf_nodes=3, random_state=0)
clf.fit(X_train, y_train)
node_indicator = clf.decision_path(X_test)
leaf_id = clf.apply(X_test)
参考资料
https://scikit-learn.org/stable/modules/tree.html
https://scikit-learn.org/stable/auto_examples/tree/plot_unveil_tree_structure.html
https://scikit-learn.org/stable/auto_examples/tree/
往期精彩回顾
适合初学者入门人工智能的路线及资料下载(图文+视频)机器学习入门系列下载机器学习及深度学习笔记等资料打印《统计学习方法》的代码复现专辑机器学习交流qq群955171419,加入微信群请扫码