线性回归(一)

一、简单线性回归

1.1 什么是简单线性回归

所谓简单,是指只有一个样本特征,即只有一个自变量;
所谓线性,是指方程是线性的;
所谓回归,是指用方程来模拟变量之间是如何关联的;
简单线性回归,其思想简单,实现容易(与其背后强大的数学性质相关)。同时也是许多强大的非线性模型(多项式回归、逻辑回归、SVM)的基础。并且其结果具有很好的可解释性。

1.2 求解思路

回归重要的任务就是拟合,找到最佳的拟合规律是最终的目标。也就是说,我们需要一条直线,最大程度的拟合样本特征和样本数据标记之间的关系。
反映到数学公式中,就是一次函数。在二维平面中,这条直线的方程就是 y = ax + b
假设我们找到了最佳拟合的直线方程:y = ax + b
在这里插入图片描述

1.3 一种基本推到思路

这是一个典型的最小二乘法问题(最小化误差的平方)
通过最小二乘法可以求出a、b的表达式:
在这里插入图片描述

1.4 代码实现

import numpy as np


class SimpleLinearRegression1:

    def __init__(self):
        """初始化 Simple Linear Regression 模型"""
        self.a_ = None
        self.b_ = None

    def fit(self, x_train, y_train):
        """根据训练数据集x_train,y_train训练 Simple Linear Regression 模型"""
        assert x_train.ndim == 1, "Simple Linear Regressor can only solve single feature training data"
        assert len(x_train) == len(y_train), "the size of x_train must be equal to the of y_train"

        x_mean = np.mean(x_train)
        y_mean = np.mean(y_train)

        num = 0.0  # 分子
        d = 0.0  # 分母
        for x_i, y_i in zip(x_train, y_train):
            num += (x_i - x_mean) * (y_i - y_mean)
            d += (x_i - x_mean) ** 2

        self.a_ = num / d
        self.b_ = y_mean - self.a_ * x_mean

        return self

    def predict(self, x_predict):
        """给定待预测数据集x_predict, 返回x_predict的结果向量"""
        assert x_predict.ndim == 1, "Simple Linear Regressor can only solve single feature training data"
        assert self.a_ is not None and self.b_ is not None, "must fit before predict!"

        return np.array([self._predict(x) for x in x_predict])

    def _predict(self, x_single):
        """给定单个待预测数据x_single, 返回x_single的预测结果值"""
        return self.a_ * x_single + self.b_

    def __repr__(self):
        return "SimpleLinearRegression1()"

代码调用

import numpy as np
import matplotlib.pyplot as plt

x = np.array([1., 2., 3., 4., 5.])
y = np.array([1., 3., 2., 3., 5.])

from playML.SimpleLinearRegression import SimpleLinearRegression1

reg1 = SimpleLinearRegression1()
reg1.fit(x, y)
y_hat1 = reg1.predict(x)
plt.scatter(x, y)
plt.plot(x, y_hat1, color='r')
plt.axis([0, 6, 0, 6])
plt.show()
  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值