import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.datasets import load_boston
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
from sklearn.inspection import permutation_importance
import matplotlib
from sklearn.datasets import load_iris
from sklearn.tree import plot_tree
from sklearn.tree import export_graphviz
from sklearn.tree import export_text
iris = load_iris()  # load the iris dataset
data = iris.data  # feature matrix
target = iris.target  # integer class labels
params = {
    # Number of boosting stages. For a multiclass target the model fits one
    # regression tree per class in every stage, so the total tree count is
    # n_estimators * n_classes.
    'n_estimators': 2,
    # Maximum depth of each weak learner (a CART regression tree).
    'max_depth': 2,
    # Shrinks the contribution of each tree.
    'learning_rate': 0.1,
}
# NOTE: despite the "reg" suffix, this is a gradient-boosting *classifier*;
# its individual trees are regression trees fitted to gradients of the loss.
GBDTreg = GradientBoostingClassifier(**params)
GBDTreg.fit(data, target)
# GBDTreg.estimators_ is a 2-D array of fitted regression trees with shape
# (n_stages, K): K is the number of classes for a multiclass target (3 for
# iris, giving shape (2, 3) here) and 1 for a binary target. Reading the shape
# from the fitted array avoids hard-coding the per-stage tree count.
n_stages, n_trees_per_stage = GBDTreg.estimators_.shape
for ii in range(n_stages):
    for jj in range(n_trees_per_stage):
        sub_tree = GBDTreg.estimators_[ii, jj]
        # One figure per tree; titled so the figures can be told apart.
        plt.figure(figsize=(15, 9))
        plot_tree(sub_tree, feature_names=iris['feature_names'])
        plt.title(f"boosting stage {ii}, tree {jj}")
        # DOT-format dump of the tree (features shown as X[i]).
        r1 = export_graphviz(sub_tree)
        print(r1)
        # Plain-text dump of the same tree, using the real feature names.
        r2 = export_text(sub_tree, feature_names=iris['feature_names'])
        print(r2)
# Predicted class labels on the training data.
y_predict = GBDTreg.predict(data)
结果如下:
plot_tree 绘制的最后一棵决策树:
打印 export_graphviz 输出的最后一棵树:
digraph Tree {
node [shape=box] ;
0 [label="X[3] <= 1.65\nfriedman_mse = 0.183\nsamples = 150\nvalue = 0.0"] ;
1 [label="X[2] <= 4.95\nfriedman_mse = 0.034\nsamples = 102\nvalue = -0.264"] ;
0 -> 1 [labeldistance=2.5, labelangle=45, headlabel="True"] ;
2 [label="friedman_mse = 0.0\nsamples = 97\nvalue = -0.953"] ;
1 -> 2 ;
3 [label="friedman_mse = 0.16\nsamples = 5\nvalue = 1.34"] ;
1 -> 3 ;
4 [label="X[3] <= 1.75\nfriedman_mse = 0.037\nsamples = 48\nvalue = 0.561"] ;
0 -> 4 [labeldistance=2.5, labelangle=-45, headlabel="False"] ;
5 [label="friedman_mse = 0.272\nsamples = 2\nvalue = 0.537"] ;
4 -> 5 ;
6 [label="friedman_mse = 0.02\nsamples = 46\nvalue = 1.604"] ;
4 -> 6 ;
}
打印 export_text 输出的最后一棵树(与上面 export_graphviz 输出的是同一棵树):
|--- petal width (cm) <= 1.65
| |--- petal length (cm) <= 4.95
| | |--- value: [-0.95]
| |--- petal length (cm) > 4.95
| | |--- value: [1.34]
|--- petal width (cm) > 1.65
| |--- petal width (cm) <= 1.75
| | |--- value: [0.54]
| |--- petal width (cm) > 1.75
| | |--- value: [1.60]