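- This post is a study log from the 365天深度学习训练营 (365-Day Deep Learning Training Camp)
- Original author: K同学啊
My environment:
● Language: Python 3.9.19
● Editor: Jupyter Lab
● Library: scikit-learn 1.5.1
Goal: use the Iris dataset to train a decision tree model, then apply it to predict petal width from sepal length, sepal width, and petal length.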
- Load the data
import pandas as pd
import numpy as np
url = "https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data"
# Column names: 花萼 = sepal, 花瓣 = petal
names = ['花萼-length', '花萼-width', '花瓣-length', '花瓣-width', 'class']
dataset = pd.read_csv(url, names=names)  # names are passed explicitly for the columns
dataset
Output:
| | 花萼-length | 花萼-width | 花瓣-length | 花瓣-width | class |
|---|---|---|---|---|---|
| 0 | 5.1 | 3.5 | 1.4 | 0.2 | Iris-setosa |
| 1 | 4.9 | 3.0 | 1.4 | 0.2 | Iris-setosa |
| 2 | 4.7 | 3.2 | 1.3 | 0.2 | Iris-setosa |
| 3 | 4.6 | 3.1 | 1.5 | 0.2 | Iris-setosa |
| 4 | 5.0 | 3.6 | 1.4 | 0.2 | Iris-setosa |
| ... | ... | ... | ... | ... | ... |
| 145 | 6.7 | 3.0 | 5.2 | 2.3 | Iris-virginica |
| 146 | 6.3 | 2.5 | 5.0 | 1.9 | Iris-virginica |
| 147 | 6.5 | 3.0 | 5.2 | 2.0 | Iris-virginica |
| 148 | 6.2 | 3.4 | 5.4 | 2.3 | Iris-virginica |
| 149 | 5.9 | 3.0 | 5.1 | 1.8 | Iris-virginica |
150 rows × 5 columns
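If the UCI URL is unreachable, a similar table can probably be built from the copy of Iris that ships with scikit-learn. This is only a sketch: `load_iris` is not how the post loads the data, the bundled copy may differ from the UCI file in a few values, and the `Iris-` prefix is added by hand here to mimic the class labels shown above.

```python
import pandas as pd
from sklearn.datasets import load_iris

# Assumption: use scikit-learn's bundled Iris copy instead of downloading from UCI.
iris = load_iris()
names = ['花萼-length', '花萼-width', '花瓣-length', '花瓣-width', 'class']

# The four measurement columns come straight from iris.data.
dataset = pd.DataFrame(iris.data, columns=names[:4])
# iris.target holds 0/1/2; map to species names and add the "Iris-" prefix by hand.
dataset['class'] = ["Iris-" + iris.target_names[t] for t in iris.target]

print(dataset.shape)
print(dataset.head())
```

The rest of the walkthrough should work unchanged on this `dataset`, since the column order is the same.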
- Split into features and target
X = dataset.iloc[:, [0, 1, 2]].values  # features: sepal length, sepal width, petal length
Y = dataset.iloc[:, 3].values          # target: petal width
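The split above only separates features from the target; all 150 rows are later used for training. If you want rows the model has never seen, one common option is `train_test_split`. The sketch below is not part of the original post: it assumes the bundled `load_iris` copy as a stand-in for the CSV, and `test_size=0.2` and `random_state=42` are arbitrary illustrative values.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

# Assumption: the bundled Iris copy stands in for the downloaded CSV.
data = load_iris().data
X = data[:, [0, 1, 2]]  # sepal length, sepal width, petal length
Y = data[:, 3]          # petal width

# Hold out 20% of the rows; test_size and random_state are illustrative choices.
X_train, X_test, Y_train, Y_test = train_test_split(
    X, Y, test_size=0.2, random_state=42
)
print(X_train.shape, X_test.shape)
```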
- Train the model
from sklearn import tree
clf = tree.DecisionTreeRegressor()  # scikit-learn's decision tree regressor
clf = clf.fit(X, Y)                 # fit the tree on the data
r = tree.export_text(clf)           # plain-text rendering of the fitted tree
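With default settings the tree keeps splitting until its leaves are nearly pure, so the text dump gets very long and shows generic names such as `feature_0`. A shallower tree with named features is easier to read. The sketch below is an illustration only: `max_depth=3`, `random_state=0`, the English feature names, and the use of `load_iris` are all assumptions that do not appear in the original post.

```python
from sklearn import tree
from sklearn.datasets import load_iris

# Assumption: the bundled Iris copy stands in for the downloaded CSV.
data = load_iris().data
X, Y = data[:, [0, 1, 2]], data[:, 3]

# max_depth=3 and random_state=0 are illustrative choices, not the post's settings.
clf = tree.DecisionTreeRegressor(max_depth=3, random_state=0)
clf.fit(X, Y)

# Hypothetical English labels for columns 0, 1, 2 of X.
feature_names = ["sepal_length", "sepal_width", "petal_length"]
r = tree.export_text(clf, feature_names=feature_names)
print(r)
```

A depth limit trades some accuracy on the training rows for a model that is smaller and usually less prone to overfitting.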
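- Predict with the model
test_x = X[[0,1,50,51,100,101], :]  # two samples from each class, taken from the training data
test_y = Y[[0,1,50,51,100,101]]
pred_target = clf.predict(test_x)  # predicted y
df = pd.DataFrame()
df["原y"] = test_y         # 原y: actual petal width
df["预测y"] = pred_target  # 预测y: predicted petal width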
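Because the six samples above were also used for training, their predictions say little about how the model generalizes. A rough way to check this is to compare error on training rows with error on held-out rows. The following is a sketch under my own assumptions (bundled `load_iris` data, a 20% hold-out, `random_state=0`), not the procedure used in the post.

```python
from sklearn import tree
from sklearn.datasets import load_iris
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.model_selection import train_test_split

# Assumption: the bundled Iris copy stands in for the downloaded CSV.
data = load_iris().data
X, Y = data[:, [0, 1, 2]], data[:, 3]
X_train, X_test, Y_train, Y_test = train_test_split(
    X, Y, test_size=0.2, random_state=0  # illustrative values
)

# Fully grown tree, as in the post, but fitted on the training part only.
clf = tree.DecisionTreeRegressor(random_state=0).fit(X_train, Y_train)

train_pred = clf.predict(X_train)
test_pred = clf.predict(X_test)

train_mse = mean_squared_error(Y_train, train_pred)
test_mse = mean_squared_error(Y_test, test_pred)
train_r2 = r2_score(Y_train, train_pred)
test_r2 = r2_score(Y_test, test_pred)

print(f"train: MSE={train_mse:.4f}, R2={train_r2:.3f}")
print(f"test:  MSE={test_mse:.4f}, R2={test_r2:.3f}")
```

A large gap between the two lines is a typical sign that the unpruned tree has memorized the training rows.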
- Print the results
print("\n===模型======")      # 模型: model
print(r)
print("\n===预测结果======")  # 预测结果: prediction results
print(df)
Output:
===模型======
|--- feature_2 <= 2.45
| |--- feature_1 <= 3.25
| | |--- feature_1 <= 2.60
| | | |--- value: [0.30]
| | |--- feature_1 > 2.60
| | | |--- feature_0 <= 4.85
| | | | |--- feature_0 <= 4.35
| | | | | |--- value: [0.10]
| | | | |--- feature_0 > 4.35
| | | | | |--- feature_2 <= 1.35
| | | | | | |--- value: [0.20]
| | | | | |--- feature_2 > 1.35
| | | | | | |--- feature_0 <= 4.50
| | | | | | | |--- value: [0.20]
| | | | | | |--- feature_0 > 4.50
| | | | | | | |--- feature_0 <= 4.65
| | | | | | | | |--- value: [0.20]
| | | | | | | |--- feature_0 > 4.65
| | | | | | | | |--- feature_1 <= 3.05
| | | | | | | | | |--- value: [0.20]
| | | | | | | | |--- feature_1 > 3.05
| | | | | | | | | |--- value: [0.20]
| | | |--- feature_0 > 4.85
| | | | |--- feature_2 <= 1.45
| | | | | |--- value: [0.20]
| | | | |--- feature_2 > 1.45
| | | | | |--- feature_1 <= 3.05
| | | | | | |--- value: [0.20]
| | | | | |--- feature_1 > 3.05
| | | | | | |--- value: [0.10]
| |--- feature_1 > 3.25
| | |--- feature_2 <= 1.55
| | | |--- feature_1 <= 4.30
| | | | |--- feature_1 <= 3.95
| | | | | |--- feature_1 <= 3.85
| | | | | | |--- feature_1 <= 3.65
| | | | | | | |--- feature_0 <= 5.30
| | | | | | | | |--- feature_2 <= 1.45
| | | | | | | | | |--- feature_1 <= 3.55
| | | | | | | | | | |--- feature_0 <= 4.80
| | | | | | | | | | | |--- value: [0.30]
| | | | | | | | | | |--- feature_0 > 4.80
| | | | | | | | | | | |--- truncated branch of depth 3
| | | | | | | | | |--- feature_1 > 3.55
| | | | | | | | | | |--- value: [0.20]
| | | | | | | | |--- feature_2 > 1.45
| | | | | | | | | |--- value: [0.20]
| | | | | | | |--- feature_0 > 5.30
| | | | | | | | |--- feature_0 <= 5.45
| | | | | | | | | |--- value: [0.40]
| | | | | | | | |--- feature_0 > 5.45
| | | | | | | | | |--- value: [0.20]
| | | | | | |--- feature_1 > 3.65
| | | | | | | |--- feature_0 <= 5.20
| | | | | | | | |--- feature_1 <= 3.75
| | | | | | | | | |--- value: [0.40]
| | | | | | | | |--- feature_1 > 3.75
| | | | | | | | | |--- value: [0.30]
| | | | | | | |--- feature_0 > 5.20
| | | | | | | | |--- value: [0.20]
| | | | | |--- feature_1 > 3.85
| | | | | | |--- value: [0.40]
| | | | |--- feature_1 > 3.95
| | | | | |--- feature_0 <= 5.35
| | | | | | |--- value: [0.10]
| | | | | |--- feature_0 > 5.35
| | | | | | |--- value: [0.20]
| | | |--- feature_1 > 4.30
| | | | |--- value: [0.40]
| | |--- feature_2 > 1.55
| | | |--- feature_0 <= 4.90
| | | | |--- value: [0.20]
| | | |--- feature_0 > 4.90
| | | | |--- feature_0 <= 5.05
| | | | | |--- feature_1 <= 3.45
| | | | | | |--- value: [0.40]
| | | | | |--- feature_1 > 3.45
| | | | | | |--- value: [0.60]
| | | | |--- feature_0 > 5.05
| | | | | |--- feature_1 <= 3.35
| | | | | | |--- value: [0.50]
| | | | | |--- feature_1 > 3.35
| | | | | | |--- feature_1 <= 3.60
| | | | | | | |--- value: [0.20]
| | | | | | |--- feature_1 > 3.60
| | | | | | | |--- feature_2 <= 1.65
| | | | | | | | |--- value: [0.20]
| | | | | | | |--- feature_2 > 1.65
| | | | | | | | |--- feature_0 <= 5.55
| | | | | | | | | |--- value: [0.40]
| | | | | | | | |--- feature_0 > 5.55
| | | | | | | | | |--- value: [0.30]
|--- feature_2 > 2.45
| |--- feature_2 <= 4.75
| | |--- feature_2 <= 4.15
| | | |--- feature_1 <= 2.65
| | | | |--- feature_2 <= 3.95
| | | | | |--- feature_2 <= 3.75
| | | | | | |--- feature_2 <= 3.15
| | | | | | | |--- value: [1.10]
| | | | | | |--- feature_2 > 3.15
| | | | | | | |--- value: [1.00]
| | | | | |--- feature_2 > 3.75
| | | | | | |--- feature_1 <= 2.45
| | | | | | | |--- value: [1.10]
| | | | | | |--- feature_1 > 2.45
| | | | | | | |--- value: [1.10]
| | | | |--- feature_2 > 3.95
| | | | | |--- feature_0 <= 5.90
| | | | | | |--- feature_0 <= 5.65
| | | | | | | |--- feature_1 <= 2.40
| | | | | | | | |--- value: [1.30]
| | | | | | | |--- feature_1 > 2.40
| | | | | | | | |--- value: [1.30]
| | | | | | |--- feature_0 > 5.65
| | | | | | | |--- value: [1.20]
| | | | | |--- feature_0 > 5.90
| | | | | | |--- value: [1.00]
| | | |--- feature_1 > 2.65
| | | | |--- feature_0 <= 5.75
| | | | | |--- feature_0 <= 5.40
| | | | | | |--- value: [1.40]
| | | | | |--- feature_0 > 5.40
| | | | | | |--- value: [1.30]
| | | | |--- feature_0 > 5.75
| | | | | |--- feature_2 <= 4.05
| | | | | | |--- feature_1 <= 2.75
| | | | | | | |--- value: [1.20]
| | | | | | |--- feature_1 > 2.75
| | | | | | | |--- value: [1.30]
| | | | | |--- feature_2 > 4.05
| | | | | | |--- value: [1.00]
| | |--- feature_2 > 4.15
| | | |--- feature_2 <= 4.45
| | | | |--- feature_0 <= 5.80
| | | | | |--- feature_2 <= 4.30
| | | | | | |--- feature_1 <= 2.95
| | | | | | | |--- feature_1 <= 2.80
| | | | | | | | |--- value: [1.30]
| | | | | | | |--- feature_1 > 2.80
| | | | | | | | |--- value: [1.30]
| | | | | | |--- feature_1 > 2.95
| | | | | | | |--- value: [1.20]
| | | | | |--- feature_2 > 4.30
| | | | | | |--- value: [1.20]
| | | | |--- feature_0 > 5.80
| | | | | |--- feature_1 <= 2.95
| | | | | | |--- value: [1.30]
| | | | | |--- feature_1 > 2.95
| | | | | | |--- feature_2 <= 4.30
| | | | | | | |--- value: [1.50]
| | | | | | |--- feature_2 > 4.30
| | | | | | | |--- value: [1.40]
| | | |--- feature_2 > 4.45
| | | | |--- feature_0 <= 5.15
| | | | | |--- value: [1.70]
| | | | |--- feature_0 > 5.15
| | | | | |--- feature_1 <= 3.25
| | | | | | |--- feature_1 <= 2.95
| | | | | | | |--- feature_2 <= 4.65
| | | | | | | | |--- feature_0 <= 5.85
| | | | | | | | | |--- value: [1.30]
| | | | | | | | |--- feature_0 > 5.85
| | | | | | | | | |--- feature_0 <= 6.55
| | | | | | | | | | |--- value: [1.50]
| | | | | | | | | |--- feature_0 > 6.55
| | | | | | | | | | |--- value: [1.30]
| | | | | | | |--- feature_2 > 4.65
| | | | | | | | |--- feature_1 <= 2.85
| | | | | | | | | |--- value: [1.20]
| | | | | | | | |--- feature_1 > 2.85
| | | | | | | | | |--- value: [1.40]
| | | | | | |--- feature_1 > 2.95
| | | | | | | |--- feature_2 <= 4.55
| | | | | | | | |--- value: [1.50]
| | | | | | | |--- feature_2 > 4.55
| | | | | | | | |--- feature_1 <= 3.05
| | | | | | | | | |--- value: [1.40]
| | | | | | | | |--- feature_1 > 3.05
| | | | | | | | | |--- feature_1 <= 3.15
| | | | | | | | | | |--- value: [1.50]
| | | | | | | | | |--- feature_1 > 3.15
| | | | | | | | | | |--- value: [1.40]
| | | | | |--- feature_1 > 3.25
| | | | | | |--- value: [1.60]
| |--- feature_2 > 4.75
| | |--- feature_2 <= 5.05
| | | |--- feature_0 <= 6.75
| | | | |--- feature_0 <= 5.80
| | | | | |--- value: [2.00]
| | | | |--- feature_0 > 5.80
| | | | | |--- feature_1 <= 2.35
| | | | | | |--- value: [1.50]
| | | | | |--- feature_1 > 2.35
| | | | | | |--- feature_0 <= 6.25
| | | | | | | |--- value: [1.80]
| | | | | | |--- feature_0 > 6.25
| | | | | | | |--- feature_2 <= 4.95
| | | | | | | | |--- feature_1 <= 2.60
| | | | | | | | | |--- value: [1.50]
| | | | | | | | |--- feature_1 > 2.60
| | | | | | | | | |--- value: [1.80]
| | | | | | | |--- feature_2 > 4.95
| | | | | | | | |--- feature_0 <= 6.50
| | | | | | | | | |--- value: [1.90]
| | | | | | | | |--- feature_0 > 6.50
| | | | | | | | | |--- value: [1.70]
| | | |--- feature_0 > 6.75
| | | | |--- feature_1 <= 2.95
| | | | | |--- value: [1.40]
| | | | |--- feature_1 > 2.95
| | | | | |--- value: [1.50]
| | |--- feature_2 > 5.05
| | | |--- feature_1 <= 3.05
| | | | |--- feature_0 <= 6.35
| | | | | |--- feature_0 <= 5.85
| | | | | | |--- feature_1 <= 2.75
| | | | | | | |--- value: [1.90]
| | | | | | |--- feature_1 > 2.75
| | | | | | | |--- value: [2.40]
| | | | | |--- feature_0 > 5.85
| | | | | | |--- feature_1 <= 2.85
| | | | | | | |--- feature_2 <= 5.35
| | | | | | | | |--- feature_0 <= 6.15
| | | | | | | | | |--- value: [1.60]
| | | | | | | | |--- feature_0 > 6.15
| | | | | | | | | |--- value: [1.50]
| | | | | | | |--- feature_2 > 5.35
| | | | | | | | |--- value: [1.40]
| | | | | | |--- feature_1 > 2.85
| | | | | | | |--- feature_0 <= 6.10
| | | | | | | | |--- value: [1.80]
| | | | | | | |--- feature_0 > 6.10
| | | | | | | | |--- value: [1.80]
| | | | |--- feature_0 > 6.35
| | | | | |--- feature_0 <= 7.50
| | | | | | |--- feature_0 <= 7.15
| | | | | | | |--- feature_1 <= 2.75
| | | | | | | | |--- feature_2 <= 5.55
| | | | | | | | | |--- value: [1.90]
| | | | | | | | |--- feature_2 > 5.55
| | | | | | | | | |--- value: [1.80]
| | | | | | | |--- feature_1 > 2.75
| | | | | | | | |--- feature_0 <= 6.60
| | | | | | | | | |--- feature_2 <= 5.55
| | | | | | | | | | |--- feature_2 <= 5.35
| | | | | | | | | | | |--- value: [2.00]
| | | | | | | | | | |--- feature_2 > 5.35
| | | | | | | | | | | |--- value: [1.80]
| | | | | | | | | |--- feature_2 > 5.55
| | | | | | | | | | |--- feature_2 <= 5.70
| | | | | | | | | | | |--- value: [2.15]
| | | | | | | | | | |--- feature_2 > 5.70
| | | | | | | | | | | |--- value: [2.20]
| | | | | | | | |--- feature_0 > 6.60
| | | | | | | | | |--- feature_0 <= 6.75
| | | | | | | | | | |--- value: [2.30]
| | | | | | | | | |--- feature_0 > 6.75
| | | | | | | | | | |--- value: [2.10]
| | | | | | |--- feature_0 > 7.15
| | | | | | | |--- feature_2 <= 5.95
| | | | | | | | |--- value: [1.60]
| | | | | | | |--- feature_2 > 5.95
| | | | | | | | |--- feature_1 <= 2.85
| | | | | | | | | |--- value: [1.90]
| | | | | | | | |--- feature_1 > 2.85
| | | | | | | | | |--- value: [1.80]
| | | | | |--- feature_0 > 7.50
| | | | | | |--- feature_2 <= 6.80
| | | | | | | |--- feature_2 <= 6.35
| | | | | | | | |--- value: [2.30]
| | | | | | | |--- feature_2 > 6.35
| | | | | | | | |--- feature_1 <= 2.90
| | | | | | | | | |--- value: [2.00]
| | | | | | | | |--- feature_1 > 2.90
| | | | | | | | | |--- value: [2.10]
| | | | | | |--- feature_2 > 6.80
| | | | | | | |--- value: [2.30]
| | | |--- feature_1 > 3.05
| | | | |--- feature_1 <= 3.25
| | | | | |--- feature_0 <= 7.05
| | | | | | |--- feature_0 <= 6.60
| | | | | | | |--- feature_1 <= 3.15
| | | | | | | | |--- value: [1.80]
| | | | | | | |--- feature_1 > 3.15
| | | | | | | | |--- feature_0 <= 6.45
| | | | | | | | | |--- value: [2.30]
| | | | | | | | |--- feature_0 > 6.45
| | | | | | | | | |--- value: [2.00]
| | | | | | |--- feature_0 > 6.60
| | | | | | | |--- feature_2 <= 5.50
| | | | | | | | |--- feature_2 <= 5.25
| | | | | | | | | |--- value: [2.30]
| | | | | | | | |--- feature_2 > 5.25
| | | | | | | | | |--- value: [2.10]
| | | | | | | |--- feature_2 > 5.50
| | | | | | | | |--- feature_2 <= 5.65
| | | | | | | | | |--- value: [2.40]
| | | | | | | | |--- feature_2 > 5.65
| | | | | | | | | |--- value: [2.30]
| | | | | |--- feature_0 > 7.05
| | | | | | |--- value: [1.80]
| | | | |--- feature_1 > 3.25
| | | | | |--- feature_2 <= 6.25
| | | | | | |--- feature_2 <= 5.85
| | | | | | | |--- feature_2 <= 5.65
| | | | | | | | |--- feature_2 <= 5.50
| | | | | | | | | |--- value: [2.30]
| | | | | | | | |--- feature_2 > 5.50
| | | | | | | | | |--- value: [2.40]
| | | | | | | |--- feature_2 > 5.65
| | | | | | | | |--- value: [2.30]
| | | | | | |--- feature_2 > 5.85
| | | | | | | |--- value: [2.50]
| | | | | |--- feature_2 > 6.25
| | | | | | |--- feature_0 <= 7.80
| | | | | | | |--- value: [2.20]
| | | | | | |--- feature_0 > 7.80
| | | | | | | |--- value: [2.00]
===预测结果======
原y 预测y
0 0.2 0.25
1 0.2 0.20
2 1.4 1.40
3 1.5 1.50
4 2.5 2.50
5 1.9 1.90
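In the table above, row 0 is predicted as 0.25 although its actual value is 0.2 and the sample was part of the training data. A plausible explanation (not verified against the tree printed above) is that another training row has identical features but a different petal width; a regression tree cannot separate such rows, so the leaf returns their mean. The toy data below is made up purely to illustrate that behavior.

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

# Made-up toy data: rows 0 and 1 have identical features but different targets.
X = np.array([
    [5.1, 3.5, 1.4],
    [5.1, 3.5, 1.4],
    [4.9, 3.0, 1.4],
])
Y = np.array([0.2, 0.3, 0.2])

clf = DecisionTreeRegressor(random_state=0).fit(X, Y)

# Rows 0 and 1 end up in the same leaf, whose value is the mean of 0.2 and 0.3.
pred = clf.predict(X)
print(pred)  # approximately [0.25, 0.25, 0.2]
```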