回归-多项式回归

1、什么是多项式回归

线性回归适用于数据呈线性分布的回归问题.如果数据样本呈明显非线性分布,线性回归模型就不再适用(下图左),而采用多项式回归可能更好(下图右).例如:

2、模型定义

与线性模型相比,多项式模型引入了高次项,自变量的指数大于1,例如一元二次方程:

y = w_0 + w_1x + w_2x^2

一元三次方程:

y = w_0 + w_1x + w_2x^2 + w_3x ^ 3

推广到一元n次方程:

y = w_0 + w_1x + w_2x^2 + w_3x ^ 3 + ... + w_nx^n

上述表达式可以简化为:

y = \sum_{i=1}^N w_ix^i

3、多项式回归与线性回归的关系

多项式回归可以理解为线性回归的扩展,在线性回归模型中添加了新的特征值.例如,要预测一栋房屋的价格,有x_1, x_2, x_3三个特征值,分别表示房子长、宽、高,则房屋价格可表示为以下线性模型:

y = w_1 x_1 + w_2 x_2 + w_3 x_3 + b

对于房屋价格,也可以用房屋的体积,而不直接使用x_1, x_2, x_3三个特征:

y = w_0 + w_1x + w_2x^2 + w_3x ^ 3

相当于创造了新的特征x, x = 长 * 宽 * 高. 以上两个模型可以解释为:

  • 房屋价格是关于长、宽、高三个特征的线性模型

  • 房屋价格是关于体积的多项式模型

因此,可以将一元n次多项式变换成n元一次线性模型.

4、代码实现

# 多项式回归示例
import numpy as np
import sklearn.linear_model as lm
import sklearn.metrics as sm
import matplotlib.pyplot as mp
import sklearn.pipeline as pl
import sklearn.preprocessing as sp

train_x, train_y = [], []   # 输入、输出样本
with open("poly_sample.txt", "rt") as f:
    for line in f.readlines():
        data = [float(substr) for substr in line.split(",")]
        train_x.append(data[:-1])
        train_y.append(data[-1])

train_x = np.array(train_x)  # 二维数据形式的输入矩阵,一行一样本,一列一特征
train_y = np.array(train_y)  # 一维数组形式的输出序列,每个元素对应一个输入样本
# print(train_x)
# print(train_y)

model = pl.make_pipeline(sp.PolynomialFeatures(3), # 多项式特征扩展,扩展最高次项为3
                         lm.LinearRegression())

# 用已知输入、输出数据集训练回归器
model.fit(train_x, train_y)
# print(model[1].coef_)
# print(model[1].intercept_)

# 根据训练模型预测输出
pred_train_y = model.predict(train_x)

# 评估指标
err4 = sm.r2_score(train_y, pred_train_y)  # R2得分, 范围[0, 1], 分值越大越好
print(err4)

# 在训练集之外构建测试集
test_x = np.linspace(train_x.min(), train_x.max(), 1000)
pre_test_y = model.predict(test_x.reshape(-1, 1)) # 对新样本进行预测

# 可视化回归曲线
mp.figure('Polynomial Regression', facecolor='lightgray')
mp.title('Polynomial Regression', fontsize=20)
mp.xlabel('x', fontsize=14)
mp.ylabel('y', fontsize=14)
mp.tick_params(labelsize=10)
mp.grid(linestyle=':')
mp.scatter(train_x, train_y, c='dodgerblue', alpha=0.8, s=60, label='Sample')

mp.plot(test_x, pre_test_y, c='orangered', label='Regression')

mp.legend()
mp.show()

结果:

 

原创不易,转载请标明出处!

  • 9
    点赞
  • 47
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值