目录
多项式回归
假如我们不是要找直线(或超平面),而是一个需要找到一个用多项式所表示的曲线(或超曲面)。例如二次曲线:y=at^2 + bt + c
散点分布时,我们也可以清楚看到,有时我们找一条曲线,它的拟合度会更高
多项式回归可以写成下面的这种形式:
多项式代码实现
import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import PolynomialFeatures
%matplotlib inline
#载入数据
data = np.genfromtxt("job.csv",delimiter=',')
x_data = data[1:,1]
y_data = data[1:,2]
plt.scatter(x_data,y_data)
plt.show()
x_data = x_data[:,np.newaxis]
y_data = y_data[:,np.newaxis]
#创建并拟合模型
model = LinearRegression()
model.fit(x_data,y_data)
#画图
plt.plot(x_data,y_data,'b.')
plt.plot(x_data,model.predict(x_data),'r')
plt.show()
以上代码是我们很熟悉的代码了,还没用到多项式回归,画出来的图如下 :
下边的代码是多项式回归:
#定义多项式回归,degree的值可以调节多项式的特征
poly_reg = PolynomialFeatures(degree = 10)
#特征处理
x_poly = poly_reg.fit_transform(x_data)#degree = 3 x_poly = x^0 x^1 x^2 x^3
#定义回归模型
lin_reg = LinearRegression()
#训练模型
lin_reg.fit(x_poly,y_data)
训练模型结束后,我们可以将图形画出来:
#画图
plt.plot(x_data,y_data,'b.')
x_test = np.linspace(1,10,100)
x_test = x_test[:,np.newaxis]
plt.plot(x_test,lin_reg.predict(poly_reg.fit_transform(x_test)),c='r') #做预测的时候要传入已经处理过的值
plt.title('Truth or Bluff (Polynomial Regression)')
plt.xlabel('Position level')
plt.ylabel('Salary')
plt.show()
对比上边的图,我们可以清楚的看到这张图的拟合度更高!