"""Fit a depth-2 decision tree regressor on the California housing data and
render the fitted tree to a PNG file.

Rendering goes through ``pydotplus``, which needs the Graphviz binaries
(https://graphviz.org/download/) to be installed.
"""
import os

import matplotlib.pyplot as plt  # noqa: F401  (unused in the visible code; kept for the rest of the file)
import pandas as pd  # noqa: F401  (unused in the visible code; kept for the rest of the file)
import pydotplus
from sklearn import tree
# The private module ``sklearn.datasets.california_housing`` was removed in
# scikit-learn 0.24; the public ``sklearn.datasets`` path works on all versions.
from sklearn.datasets import fetch_california_housing

# Column indices of the two features used for training: indices 6 and 7,
# i.e. the 7th and 8th features. Used both to slice the data and to label
# the exported tree, so the two can never disagree.
FEATURE_IDX = [6, 7]
# Destination of the rendered tree image.
OUTPUT_PNG = "./data/dtr_white_background.png"

# Load scikit-learn's bundled California housing dataset.
housing = fetch_california_housing()
# Print the dataset description.
print(housing.DESCR)
# housing.data has shape (20640, 8): 20640 samples with 8 features each.

# Regression tree limited to a maximum depth of 2.
dtr = tree.DecisionTreeRegressor(max_depth=2)
# Train on the selected feature columns only.
dtr.fit(housing.data[:, FEATURE_IDX], housing.target)

# Export the fitted tree as Graphviz DOT source; out_file=None makes
# export_graphviz return the DOT text as a string instead of writing a file.
dot_data = tree.export_graphviz(
    dtr,
    out_file=None,
    feature_names=[housing.feature_names[i] for i in FEATURE_IDX],
    filled=True,
    impurity=False,
    rounded=True,
)

# Build a pydotplus graph object from the DOT text.
graph = pydotplus.graph_from_dot_data(dot_data)
# Ensure the output directory exists; writing into a missing directory fails.
os.makedirs(os.path.dirname(OUTPUT_PNG), exist_ok=True)
# Render the tree and save it as a PNG.
graph.write_png(OUTPUT_PNG)
# The rendered decision tree is written to ./data/dtr_white_background.png.