代码
# 决策树用于拟合
import numpy as np
import matplotlib.pyplot as plt
from sklearn.tree import DecisionTreeRegressor
def make_noisy_sine(n=100, low=-3.0, high=3.0, noise=0.05, rng=None):
    """Sample ``n`` sorted points of ``sin(x)`` with additive Gaussian noise.

    Parameters
    ----------
    n : int
        Number of samples.
    low, high : float
        ``x`` is drawn uniformly from the half-open interval ``[low, high)``.
    noise : float
        Standard deviation of the Gaussian noise added to ``sin(x)``.
    rng : numpy.random.Generator or None
        Source of randomness. ``None`` uses NumPy's legacy global state
        (``np.random``), so ``np.random.seed`` controls reproducibility and
        the draws match ``np.random.rand(n)`` followed by ``np.random.randn(n)``.

    Returns
    -------
    x : ndarray of shape (n, 1)
        Sorted inputs as a column: ``n`` samples with one feature each, the
        2-D layout scikit-learn estimators expect for ``fit``/``predict``.
    y : ndarray of shape (n,)
        Noisy targets ``sin(x) + noise * N(0, 1)``.
    """
    if rng is None:
        # Both the np.random module and a Generator expose random() and
        # standard_normal(), so either can be used interchangeably here.
        rng = np.random
    x = np.sort(rng.random(n) * (high - low) + low)
    y = np.sin(x) + rng.standard_normal(n) * noise
    # reshape (not transpose): turn the 1-D vector into an (n, 1) column.
    return x.reshape(-1, 1), y


def main():
    """Fit decision-tree regressors to noisy sine data and plot the fits."""
    x, y = make_noisy_sine(100)
    print(y)
    print(x)

    # Shared evaluation grid for both plots.
    x_test = np.linspace(-3, 3, 50).reshape(-1, 1)

    # Decision-tree *regressor*. `criterion` is left at its default, which is
    # mean squared error in every scikit-learn release ('mse' before 1.0,
    # 'squared_error' since); passing 'mse' explicitly fails on >= 1.2.
    dt = DecisionTreeRegressor(max_depth=9)
    dt.fit(x, y)
    y_hat = dt.predict(x_test)
    plt.plot(x, y, 'r*', ms=10, label='Actual')
    plt.plot(x_test, y_hat, 'g-', linewidth=2, label='Predict')
    plt.legend(loc='upper left')
    plt.grid(True)
    plt.show()

    # Compare how the maximum tree depth affects the fit.
    depths = [2, 4, 6, 8, 10]
    colors = 'rgbmy'
    dtr = DecisionTreeRegressor()
    plt.plot(x, y, 'ko', ms=6, label='Actual')
    for depth, color in zip(depths, colors):
        # Reuse one estimator; set_params changes only max_depth before refit.
        dtr.set_params(max_depth=depth)
        dtr.fit(x, y)
        y_hat = dtr.predict(x_test)
        plt.plot(x_test, y_hat, '-', color=color, linewidth=2,
                 label='Depth=%d' % depth)
    plt.legend(loc='upper left')
    # Positional argument: the keyword `b` was renamed to `visible` in
    # Matplotlib 3.5, so `grid(b=True)` breaks on newer releases.
    plt.grid(True)
    plt.show()


if __name__ == "__main__":
    main()
输出（示例运行结果；数据为随机生成，每次运行的数值都会不同）：
(100,)
[-0.13747212 -0.32696352 -0.34299033 -0.37734289 -0.30216829 -0.41908633
-0.42649759 -0.55874875 -0.47470554 -0.50349372 -0.60084058 -0.72667652
-0.88731673 -0.85007184 -0.80980603 -0.89046954 -0.92967645 -1.01708456
-0.96413472 -1.00831613 -1.06009149 -0.98629175 -0.99021064 -0.88084281
-0.90996548 -0.89476142 -0.80952269 -0.83540464 -0.76614234 -0.75365537
-0.51213752 -0.53558931 -0.5158306 -0.51753766 -0.47760662 -0.49621367
-0.35078086 -0.4007496 -0.37787176 -0.35708106 -0.33543894 0.05607983
-0.04710956 0.02358386 0.13753866 0.22134074 0.36428241 0.38151542
0.42788242 0.47056583 0.47299773 0.57728474 0.69424008 0.68866846
0.74362813 0.85661517 0.79570145 0.72801613 0.83298817 0.91378756
0.92111679 1.01043268 0.96942097 0.989228 0.97144073 0.95992022
0.90630972 0.94775525 1.00992384 1.00577511 1.0092611 1.06641845
1.01056367 0.92489214 0.99751525 0.9716967 0.90643779 0.93410205
0.90237971 0.93908154 0.88156985 0.84080906 0.81336031 0.81184513
0.77923751 0.71039144 0.65860142 0.68686109 0.66221666 0.46724569
0.49525938 0.33146802 0.26010888 0.33738618 0.2700388 0.25114123
0.25704015 0.16070012 0.10970704 0.24002726]
[[-2.98291781]
[-2.84917313]
[-2.83638979]
[-2.7860988 ]
[-2.77855502]
[-2.71508228]
[-2.69430568]
[-2.66073423]
[-2.61173817]
[-2.58799644]
[-2.44903871]
[-2.25958656]
[-2.23913379]
[-2.23306881]
[-2.22936331]
[-2.10650054]
[-1.8938904 ]
[-1.7209747 ]
[-1.66178009]
[-1.41848746]
[-1.30166897]
[-1.23206966]
[-1.21439096]
[-1.06802454]
[-1.0338541 ]
[-0.98878396]
[-0.97796016]
[-0.95109667]
[-0.87134449]
[-0.80900001]
[-0.56761021]
[-0.54240145]
[-0.53489081]
[-0.44861858]
[-0.43542907]
[-0.42543931]
[-0.41576859]
[-0.37648098]
[-0.36547507]
[-0.35646849]
[-0.33521726]
[ 0.03031728]
[ 0.03678361]
[ 0.08126219]
[ 0.11222954]
[ 0.18367956]
[ 0.36263569]
[ 0.37936774]
[ 0.39229896]
[ 0.40675994]
[ 0.57247375]
[ 0.59730802]
[ 0.75540724]
[ 0.80358303]
[ 0.8803318 ]
[ 0.92195617]
[ 0.92913085]
[ 0.98071443]
[ 1.08071656]
[ 1.1088603 ]
[ 1.14930987]
[ 1.32788011]
[ 1.34569531]
[ 1.36755187]
[ 1.40643578]
[ 1.43714342]
[ 1.44562189]
[ 1.51583702]
[ 1.53488103]
[ 1.58985047]
[ 1.6181127 ]
[ 1.641521 ]
[ 1.65421212]
[ 1.6710831 ]
[ 1.68895352]
[ 1.76660029]
[ 1.84145428]
[ 1.88972944]
[ 1.96540222]
[ 1.98953008]
[ 1.9968826 ]
[ 2.02694326]
[ 2.09358637]
[ 2.22188409]
[ 2.24214962]
[ 2.38443801]
[ 2.38563411]
[ 2.39967366]
[ 2.45955993]
[ 2.59606735]
[ 2.66426555]
[ 2.7828691 ]
[ 2.78397954]
[ 2.83419346]
[ 2.84315114]
[ 2.89945669]
[ 2.90104871]
[ 2.97192459]
[ 2.97521555]
[ 2.99002727]]