python线性回归算法简介_线性回归算法Python小实现

demo:数据集

#!/usr/bin/env python3

# -*- coding: utf-8 -*-

"""

Created on Fri Mar 9 20:05:51 2018

@author: lisir

"""

import csv

import numpy as np

import matplotlib.pyplot as plt

import matplotlib as mpl

import pandas as pd

from sklearn.model_selection import train_test_split

from sklearn.linear_model import LinearRegression

from pprint import pprint

if __name__ == "__main__":

csv_path = "/home/lisir/test/c_test/ml/tmp/10.Regression/Advertising.csv"

# read data method one

f = open(csv_path)

x = []

y = []

for i, d in enumerate(f):

if i == 0:

continue

d = d.strip()

if not d:

continue

# split data by ,

data = list(map(float, d.split(',')))

x.append(data[1:3])

y.append(data[-1])

pprint(x)

pprint(y)

print("=========================")

x = np.array(x)

y = np.array(y)

print(x)

print(y)

print("==========================")

plt.figure(figsize=(6, 9))

plt.subplot(211)

plt.plot(x[:,0], y, 'ro')

plt.title("TV")

plt.grid()

plt.subplot(212)

plt.plot(x[:,1], y, 'g^')

plt.title("Radio")

plt.grid()

# =============================================================================

# plt.subplot(313)

# plt.plot(x[:,2], y, 'b*')

# plt.title("Nespaper")

# plt.grid()

# =============================================================================

plt.tight_layout()

plt.show()

# feed data

x_train, x_test, y_train, y_test = train_test_split(

x, y, train_size=0.8, random_state=1)

print("---------------------------")

print(x_train)

print(y_train)

print("----------------------------")

# read model

linearRegression = LinearRegression()

# gradien descent method to fit data parameters

model = linearRegression.fit(x_train, y_train)

print("=============================")

print(model)

print(linearRegression.coef_)

print(linearRegression.intercept_)

print("=============================")

y_hat = linearRegression.predict(np.array(x_test))

mse = np.average((y_hat - np.array(y_test)) ** 2)

rmse = np.sqrt(mse)

print(mse)

print(rmse)

t = np.arange(len(x_test))

plt.plot(t, y_test, 'r-', Linewidth=2, label="real data")

plt.plot(t, y_hat, 'g-', Linewidth=2, label="predict data")

plt.legend(loc='upper right')

plt.title("linear regression sales", fontsize=18)

plt.grid()

plt.show()采取三个行业的特征进行分析预测:

66b3c8075ae2cd8a5b037b24b34582b7.png

采取前面两个线性特征比较直观的特征进行试验:

8975d367c240215059d2b66fcc881e7b.png

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值