对随机森林进行可视化
安装一些需要的库:
pip install graphviz
pip install pydotplus
在Jupyter notebook 中进行随机森林可视化:
from sklearn import datasets
from sklearn.ensemble import RandomForestClassifier
from IPython.core.display import HTML, display
from sklearn import tree
import pydotplus
# 使用自带的iris数据
iris = datasets.load_iris()
X = iris.data
y = iris.target
# 训练模型,限制树的最大深度4
clf = RandomForestClassifier(max_depth=4)
#拟合模型
clf.fit(X, y)
estimators = clf.estimators_
for m in estimators:
dot_data = tree.export_graphviz(m, out_file=None,
feature_names=iris.feature_names,
class_names=iris.target_names,
filled=True, rounded=True,
special_characters=True)
graph = pydotplus.graph_from_dot_data(dot_data)
# 使用ipython的终端jupyter notebook显示。
svg = graph.create_svg()
if hasattr(svg, "decode"):
svg = svg.decode("utf-8")
html = HTML(svg)
display(html)
结果图就不放了