【机器学习】监督学习算法 2.3.5 决策树

全文代码如下

需要在 GitHub 上下载相关数据集：下载整个包，在其 data 目录中找到 ram_price.csv，并将其放到脚本的运行目录下即可（代码中以相对路径 "ram_price.csv" 读取）。

点这下载

#决策树

import mglearn
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.tree import DecisionTreeClassifier
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split

cancer = load_breast_cancer()
x_train, x_test, y_train, y_test = train_test_split(
    cancer.data, cancer.target, stratify=cancer.target, random_state=42
)

# Fit an unpruned tree first (max_depth=None is the sklearn default), then one
# limited to depth 4, reporting train/test accuracy for each. After the loop
# `tree` holds the depth-4 model, which the code that follows uses.
for depth in (None, 4):
    tree = DecisionTreeClassifier(max_depth=depth, random_state=0)
    tree.fit(x_train, y_train)
    train_acc = tree.score(x_train, y_train)
    print('accuracy on training set:{:.3f}'.format(train_acc))
    test_acc = tree.score(x_test, y_test)
    print('accuracy on test set:{:.3f}'.format(test_acc))

from sklearn.tree import export_graphviz
import graphviz
from IPython import display

# Write the fitted tree to a Graphviz .dot file, then read it back for rendering.
export_graphviz(tree, out_file='tree.dot', class_names=['malignant', 'benign'],
                feature_names=cancer.feature_names, impurity=False, filled=True)
# sklearn's export_graphviz opens the output file as UTF-8; read it back with the
# same encoding rather than the platform default (e.g. GBK on Chinese Windows).
with open('tree.dot', encoding='utf-8') as f:
    dot_graph = f.read()
# A bare `graphviz.Source(...)` expression only renders when it is the last
# expression of a notebook cell; display it explicitly so it is not a no-op.
tree_graph = graphviz.Source(dot_graph)
display.display(tree_graph)

print('feature importances:{}'.format(tree.feature_importances_))

def plot_feature_importances_cancer(model, feature_names=None):
    """Draw a horizontal bar chart of a fitted model's feature importances.

    Parameters
    ----------
    model : fitted estimator exposing ``feature_importances_``
        For example, the DecisionTreeClassifier trained on the breast-cancer data.
    feature_names : sequence of str, optional
        Labels for the bars, one per feature. Defaults to
        ``cancer.feature_names`` (the module-level breast-cancer dataset), so
        existing one-argument calls behave exactly as before.

    The figure is shown immediately with ``plt.show()``.
    """
    if feature_names is None:
        feature_names = cancer.feature_names
    n_features = len(feature_names)
    plt.barh(range(n_features), model.feature_importances_, align='center')
    plt.yticks(range(n_features), feature_names)
    plt.xlabel('feature importance')
    plt.ylabel("feature")
    plt.show()

from IPython import display

plot_feature_importances_cancer(tree)

# mglearn's plot_tree_not_monotone() returns an object that is passed to
# display.display below (presumably a Graphviz rendering of a small tree;
# confirm against mglearn). Bind it to its own name instead of `tree`, so the
# fitted breast-cancer classifier is not silently overwritten.
tree_graph = mglearn.plots.plot_tree_not_monotone()
display.display(tree_graph)
plt.show()

# Historical computer memory (RAM) prices
ram_prices = pd.read_csv("ram_price.csv")

plt.semilogy(ram_prices.date, ram_prices.price)
plt.xlabel('year')
plt.ylabel('price in $/mbyte')
plt.show()

from sklearn.tree import DecisionTreeRegressor
from sklearn.linear_model import LinearRegression

# Train only on dates before 2000 and keep 2000 onwards as test data, so the
# predictions for those later dates lie outside the training range.
data_train = ram_prices[ram_prices.date < 2000]
data_test = ram_prices[ram_prices.date >= 2000]

# scikit-learn expects a 2-D (n_samples, 1) feature matrix. Convert to NumPy
# first: indexing a pandas Series with `[:, np.newaxis]` was deprecated and
# raises an error in pandas >= 2.0.
x_train = data_train.date.to_numpy()[:, np.newaxis]
# Fit on log(price), matching the log-scale (semilogy) plots.
y_train = np.log(data_train.price)

tree = DecisionTreeRegressor().fit(x_train, y_train)
linear_reg = LinearRegression().fit(x_train, y_train)

# Predict over every date (train and test) in the same 2-D layout.
x_all = ram_prices.date.to_numpy()[:, np.newaxis]

pred_tree = tree.predict(x_all)
pred_lr = linear_reg.predict(x_all)

# Undo the log transform to get predictions back in the original price units.
price_tree = np.exp(pred_tree)
price_lr = np.exp(pred_lr)

plt.semilogy(data_train.date, data_train.price, label='training data')
plt.semilogy(data_test.date, data_test.price, label='test data')
plt.semilogy(ram_prices.date, price_tree, label='tree prediction')
plt.semilogy(ram_prices.date, price_lr, label='linear prediction')
plt.legend()
plt.show()


在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

  • 2
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 3
    评论
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值