包含全部示例的代码仓库见 GitHub
1 导入库
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import PolynomialFeatures
from sklearn.metrics import mean_squared_error
2 数据准备
# Training data: 20 evenly spaced points on [0, 10] following a cubic trend,
# plus uniform noise in [0, 160) and a constant offset of 10.
# A fixed seed makes the data -- and therefore every MSE reported in the
# "model test" section -- reproducible from run to run.
x = np.linspace(0, 10, 20)
y = x**3 + np.random.default_rng(42).random(20) * 160 + 10
# Visualize the raw (x, y) training points before fitting any model.
plt.scatter(x, y)
3 模型构建
3.1 欠拟合
# Underfitting: a plain straight-line model on the single raw feature x.
# sklearn estimators expect a 2-D feature matrix, hence reshape(-1, 1);
# fit() returns the estimator itself, so predict() can be chained onto it.
model = LinearRegression()
y_pred = model.fit(x.reshape(-1, 1), y).predict(x.reshape(-1, 1))
plt.scatter(x, y)
plt.plot(x, y_pred, c='r')
3.2 正常拟合
# Reasonable fit: expand x into polynomial features up to degree 4, then fit
# an ordinary linear model on that expanded feature matrix.
q4 = PolynomialFeatures(degree=4)
x4 = q4.fit_transform(x.reshape(-1, 1))
model4 = LinearRegression().fit(x4, y)  # fit() returns the fitted estimator
y_pred4 = model4.predict(x4)
plt.scatter(x, y)
plt.plot(x, y_pred4, c='r')
3.3 过拟合
# Overfitting: degree-20 expansion yields 21 feature columns (including the
# bias column) for only 20 training points, giving the linear model enough
# freedom to chase the noise rather than the underlying trend.
q20 = PolynomialFeatures(degree=20)
x20 = q20.fit_transform(x.reshape(-1, 1))
model20 = LinearRegression().fit(x20, y)  # fit() returns the fitted estimator
y_pred20 = model20.predict(x20)
plt.scatter(x, y)
plt.plot(x, y_pred20, c='r')
# New, unseen data: 30 points on [0, 20], i.e. twice the training range
# [0, 10], so everything beyond x = 10 is extrapolation for the models above.
# Same generating process as the training data (cubic trend + uniform noise
# in [0, 160) + 10). A separate fixed seed keeps it reproducible and
# independent of the training noise.
x2 = np.linspace(0, 20, 30)
y2 = x2**3 + np.random.default_rng(7).random(30) * 160 + 10
用新的数据集（x ∈ [0, 20]，超出训练范围 [0, 10]）检验过拟合模型的预测效果
# Predict the new data with the degree-20 model.
# Reuse the transformer already fitted on the training data (transform, not
# fit_transform): evaluation data must go through the same fitted
# preprocessing, never a freshly re-fitted copy.
x22 = q20.transform(x2.reshape(-1, 1))
y_pred22 = model20.predict(x22)
# Plot the data actually being predicted (x2, y2) alongside the training
# points, so the gap between prediction and new data is visible.
plt.scatter(x2, y2, label='new data')
plt.scatter(x, y, label='training data')
plt.plot(x2, y_pred22, c='r', label='degree-20 prediction')
plt.legend()
4 模型测试（在训练集上计算均方误差 MSE；注意：训练误差最低的模型并不一定泛化能力最好）
欠拟合
# Training-set MSE of the underfit (straight-line) model.
# sklearn's signature is mean_squared_error(y_true, y_pred); MSE is symmetric,
# so the value is unchanged, but following the documented order avoids
# silent errors if this is later swapped for a non-symmetric metric.
mean_squared_error(y, model.predict(x.reshape(-1, 1)))
# output
14478.676863285145
正常拟合
# Training-set MSE of the degree-4 polynomial model.
# Arguments follow sklearn's documented (y_true, y_pred) order.
mean_squared_error(y, model4.predict(x4))
# output
1757.2597799281452
过拟合
# Training-set MSE of the degree-20 polynomial model.
# Arguments follow sklearn's documented (y_true, y_pred) order.
# NOTE: a low training MSE here reflects fitting the training noise, not
# better generalization -- compare against predictions on the new data.
mean_squared_error(y, model20.predict(x20))
# output
1208.419154080902