Demo:基于 Advertising.csv 数据集的线性回归示例
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri Mar 9 20:05:51 2018
@author: lisir
"""
import csv
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from pprint import pprint
def load_advertising(csv_path):
    """Read the feature and target columns from an advertising CSV file.

    The first line is treated as a header and skipped, and blank lines are
    ignored. Every remaining line is split on commas and each field is
    parsed as a float.

    Args:
        csv_path: Path of the CSV file to read.

    Returns:
        A ``(features, targets)`` tuple of lists. ``features[i]`` holds the
        values at column indexes 1 and 2 of data row ``i``; ``targets[i]``
        holds the last column of that row.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If a field cannot be parsed as a float.
    """
    features = []
    targets = []
    # The context manager closes the file even if a row fails to parse.
    with open(csv_path) as f:
        next(f, None)  # skip the header line
        for line in f:
            line = line.strip()
            if not line:
                continue
            row = [float(field) for field in line.split(',')]
            features.append(row[1:3])
            targets.append(row[-1])
    return features, targets


if __name__ == "__main__":
    csv_path = "/home/lisir/test/c_test/ml/tmp/10.Regression/Advertising.csv"

    # Load the two feature columns and the target column.
    x, y = load_advertising(csv_path)
    pprint(x)
    pprint(y)
    print("=========================")
    x = np.array(x)
    y = np.array(y)
    print(x)
    print(y)
    print("==========================")

    # Scatter each feature against the target to eyeball how linear the
    # relationship is.
    # NOTE: to plot a third feature as well, widen the slice in
    # load_advertising to row[1:4] and switch to a 3-row subplot grid.
    plt.figure(figsize=(6, 9))
    plt.subplot(211)
    plt.plot(x[:, 0], y, 'ro')
    plt.title("TV")
    plt.grid()
    plt.subplot(212)
    plt.plot(x[:, 1], y, 'g^')
    plt.title("Radio")
    plt.grid()
    plt.tight_layout()
    plt.show()

    # Hold out 20% of the rows for testing; the fixed seed keeps the split
    # reproducible between runs.
    x_train, x_test, y_train, y_test = train_test_split(
        x, y, train_size=0.8, random_state=1)
    print("---------------------------")
    print(x_train)
    print(y_train)
    print("----------------------------")

    # LinearRegression fits by ordinary least squares (not gradient
    # descent); fit() returns the estimator itself.
    model = LinearRegression().fit(x_train, y_train)
    print("=============================")
    print(model)
    print(model.coef_)
    print(model.intercept_)
    print("=============================")

    # Evaluate on the held-out rows.
    y_hat = model.predict(x_test)
    mse = np.mean((y_hat - y_test) ** 2)
    rmse = np.sqrt(mse)
    print(mse)
    print(rmse)

    # Overlay actual and predicted values by test-sample index.
    # Keyword names are case-sensitive: it must be `linewidth`, since
    # `Linewidth` is rejected as an unknown Line2D property.
    t = np.arange(len(x_test))
    plt.plot(t, y_test, 'r-', linewidth=2, label="real data")
    plt.plot(t, y_hat, 'g-', linewidth=2, label="predict data")
    plt.legend(loc='upper right')
    plt.title("linear regression sales", fontsize=18)
    plt.grid()
    plt.show()
采用三种广告渠道(TV、Radio、Newspaper)的特征进行分析预测:
仅采用前两个与销量线性关系较为直观的特征(TV、Radio)进行试验: