# advcase.py-20180704

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Jul  3 17:43:53 2018

@author: vicky
"""

# Third-party imports
import pandas as pd
import numpy as np
import statsmodels.formula.api as smf
try:
    # scikit-learn >= 0.18
    from sklearn.model_selection import train_test_split
except ImportError:
    # Older scikit-learn; sklearn.cross_validation was removed in 0.20.
    from sklearn.cross_validation import train_test_split
from sklearn.metrics import mean_squared_error
import matplotlib.pyplot as plt
import seaborn as sns

# Load the Advertising dataset and print descriptive statistics for every
# variable.
# NOTE(review): the original script kept the load commented out (a non-raw
# Windows path such as 'C:\Users\...' is a SyntaxError on Python 3 because of
# the "\U" escape), so `data` was never defined and every later statement
# failed with NameError. The author's copy lived at
# C:\Users\wenyun.wxw\Desktop\Advertising.csv -- adjust DATA_PATH as needed.
DATA_PATH = 'Advertising.csv'
data = pd.read_csv(DATA_PATH)
# A bare expression is discarded when run as a script, so print the table.
print(data.describe())

# Scatter plot of each advertising channel against Sales, one panel per
# channel, with a regression fit overlaid (kind="reg").
# `height` is the seaborn >= 0.9 name for the old `size` argument.
sns.pairplot(data, x_vars=["TV", "Radio", "Newspaper"], y_vars="Sales",
             height=5, aspect=0.7, kind="reg")
# `sns.plt` no longer exists in seaborn; call pyplot directly.
plt.show()

# Pairwise correlation coefficients between the variables.
corr = data.corr()
print(corr)

# Correlation heat map. Reuse the matrix computed above instead of calling
# data.corr() a second time, and label the ticks from the matrix itself so
# labels always match its size (data.columns may include columns that
# corr() leaves out).
plt.imshow(corr, cmap=plt.cm.Blues, interpolation='nearest')
plt.colorbar()
tick_marks = list(range(len(corr.columns)))
plt.xticks(tick_marks, corr.columns, rotation='vertical')
plt.yticks(tick_marks, corr.columns)
# Finish this figure; otherwise it stays the current figure and the next
# pyplot call in the script draws into the heat-map axes.
plt.show()

# Split the rows 80/20 into training and test sets; the fixed seed keeps the
# split (and every result below) reproducible across runs.
Train, Test = train_test_split(data, random_state=1234, train_size=0.8)

# Fit three OLS models on the training set and print their summaries
# (a bare `.summary()` expression is discarded when run as a script).

# Model 1: Sales regressed on all three advertising channels.
fit = smf.ols('Sales~TV+Radio+Newspaper', data=Train).fit()
print(fit.summary())

# Model 2: drop Newspaper. The formula only references TV and Radio, so the
# column does not need to be removed from the frame first.
fit2 = smf.ols('Sales~TV+Radio', data=Train).fit()
print(fit2.summary())

# Model 3: TV and Radio plus their TV x Radio interaction term.
fit3 = smf.ols('Sales~TV+Radio+TV:Radio', data=Train).fit()
print(fit3.summary())

# Test-set predictions. Each formula model builds its own design matrix from
# the columns it references, so the full Test frame can be passed to all three.
pred = fit.predict(exog=Test)
pred2 = fit2.predict(exog=Test)
pred3 = fit3.predict(exog=Test)

# Root-mean-squared error of each model's predictions on the test set.
RMSE = np.sqrt(mean_squared_error(Test.Sales, pred))
RMSE2 = np.sqrt(mean_squared_error(Test.Sales, pred2))
RMSE3 = np.sqrt(mean_squared_error(Test.Sales, pred3))
# Label each value with its model formula. Previously all three lines printed
# the same misspelled "RMES=" prefix, so the models could not be told apart.
print('RMSE (Sales~TV+Radio+Newspaper) = %.4f' % RMSE)
print('RMSE (Sales~TV+Radio)           = %.4f' % RMSE2)
print('RMSE (Sales~TV+Radio+TV:Radio)  = %.4f' % RMSE3)

# Observed vs. predicted test-set Sales for model 1 (`pred`).
plt.style.use('ggplot')
# Start a fresh figure so the points are not drawn into a figure that is
# still open from earlier in the script.
plt.figure()
plt.scatter(Test.Sales, pred, c='b', label='Observations')
# Reference line y = x: a point on it was predicted exactly. The old
# "Fitted line" joined (min real, min pred) to (max real, max pred), which is
# neither a fitted line nor the perfect-prediction diagonal.
diag_lims = [min(Test.Sales.min(), pred.min()),
             max(Test.Sales.max(), pred.max())]
plt.plot(diag_lims, diag_lims, 'r--', lw=2, label='Perfect prediction (y = x)')
plt.title('Real Values VS. Predict Values')
plt.xlabel('Real Values')
plt.ylabel('Prediction Values')
plt.legend(loc='upper left')
plt.show()

# Residual plot for model 1: residual (observed minus predicted test-set
# Sales) against the fitted value. Ideally the points scatter around zero
# with no visible pattern.
residuals = Test.Sales - pred
plt.style.use('ggplot')
plt.scatter(pred, residuals, c='b', label='Residuals')
plt.title('Residual Plot')
plt.xlabel('Fitted Values')
plt.ylabel('Residuals')
plt.legend(loc='upper left')
plt.show()

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值