线性回归(scikit-learn 实战)

# 线性回归(skit-learn 实战) [线性回归API](http://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LinearRegression.html#sklearn.linear_model.LinearRegression) ## 引入包
import csv
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error
import seaborn as sns
## 数据读取与特征选择 广告投入与销售额数据,共有5列,分别为id,电视广告投入、无限广播投入、报纸投入和销售额。 [下载](https://pan.baidu.com/s/13Foo6dEf2aYqRr2v4NGrfA)
path = './Advertising.csv'
# pandas读入
data = pd.read_csv(path)    # TV、Radio、Newspaper、Sales
data.head()
Unnamed: 0TVRadioNewspaperSales
01230.137.869.222.1
1244.539.345.110.4
2317.245.969.39.3
34151.541.358.518.5
45180.810.858.412.9
# 用pairplot画图,观察Sales与各特征之间的关系
sns.pairplot(data, x_vars=['TV','Radio','Newspaper'], y_vars='Sales')

# 从上图可以看出,Sales与TV具有较强的线性关系,仅选用TV一个feature
x = data[['TV']]
y = data['Sales']
# 划分训练集和测试集
x_train, x_test, y_train, y_test = train_test_split(x, y, random_state=1)

模型训练

# 模型训练
linreg = LinearRegression()
model = linreg.fit(x_train, y_train)
model
LinearRegression(copy_X=True, fit_intercept=True, n_jobs=1, normalize=False)
# 系数
model.coef_
array([0.04802945])
# 截距
linreg.intercept_
6.91197261886872

精度评定

y_hat = linreg.predict(np.array(x_test))

均方误差

MSE=1NN(yy^)2 M S E = 1 N ∑ N ( y − y ^ ) 2

mse = mean_squared_error(y_test, y_hat)
mse
10.310069587813155

R2 R 2

  • 样本总平方和TSS(Total Sum of Squares): TSS=(yiy¯¯¯)2 T S S = ∑ ( y i − y ¯ ) 2
  • 残差平方和RSS(Residual Sum of Squares): RSS=(yiy^)2 R S S = ∑ ( y i − y ^ ) 2
  • R2=1RSSTSS R 2 = 1 − R S S T S S
    • R2 R 2 越大,拟合效果越好
    • R2 R 2 最优值为1,若模型拟合效果较差,可能为负
    • 若预测值恒为样本均值, R2 R 2 = 0
score = model.score(x_test,y_test)
score
0.5590828580007852

可视化结果

t = np.arange(len(x_test))
plt.plot(t, y_test, 'r-', linewidth=2, label='Test')
plt.plot(t, y_hat, 'g-', linewidth=2, label='Predict')
plt.legend(loc='upper right')
plt.grid()

plt.scatter(x_test, y_test,  color='red')
plt.plot(x_test, y_hat, color='green', linewidth=3)

这里写图片描述

  • 0
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值