import numpy as np from sklearn.tree import DecisionTreeRegressor import matplotlib.pyplot as plt #创建一个随机数据集 rng = np.random.RandomState(1) X = np.sort(5*rng.rand(80,1),axis = 0) y = np.sin(X).ravel() y[::5] += 3 * (0.5 - rng.rand(16)) #训练三种不同最大深度的回归树模型,并拟合 clf1 = DecisionTreeRegressor(max_depth=2).fit(X,y) clf2 = DecisionTreeRegressor(max_depth=4).fit(X,y) clf3 = DecisionTreeRegressor(max_depth=16).fit(X,y) #利用回归模型预测 X_test = np.arange(0.0,5.0,0.01)[:,np.newaxis] y1 = clf1.predict(X_test) y2 = clf2.predict(X_test) y3 = clf3.predict(X_test) #绘制散点图和回归曲线 plt.figure() plt.scatter(X,y,c = 'y',label = 'data') plt.plot(X_test,y1,c = 'g',label = 'max_depth = 2',linewidth = 2) plt.plot(X_test,y2,c = 'r',label = 'max_depth = 4',linewidth = 2) plt.plot(X_test,y3,c = 'b',label = 'max_depth = 16',linewidth = 2) plt.xlabel('data') plt.ylabel('target') plt.title('Decision Tree Regression') plt.legend() plt.show()
决策树回归算法
最新推荐文章于 2024-07-02 14:27:14 发布