什么是多项式回归
线性回归
多项式回归
把
x
2
x^{2}
x2理解了一个特征,这样依然是一个线性回归的式子。
什么是多项式回归
import numpy as np
import matplotlib.pyplot as plt
x = np.random.uniform(-3, 3, size=100)
X = x.reshape(-1, 1)
y = 0.5 * x**2 + x + 2 + np.random.normal(0, 1, 100)
plt.scatter(x, y)
plt.show()
线性回归?
from sklearn.linear_model import LinearRegression
lin_reg = LinearRegression()
lin_reg.fit(X, y)
y_predict = lin_reg.predict(X)
plt.scatter(x, y)
plt.plot(x, y_predict, color='r')
plt.show()
解决方案, 添加一个特征
X2 = np.hstack([X, X**2])
X2.shape
(100, 2)
lin_reg2 = LinearRegression()
lin_reg2.fit(X2, y)
y_predict2 = lin_reg2.predict(X2)
plt.scatter(x, y)
plt.plot(np.sort(x), y_predict2[np.argsort(x)], color='r')
plt.show()
lin_reg2.coef_
array([ 0.99870163, 0.54939125])
lin_reg2.intercept_
1.8855236786516001