最小二乘法的python的实现

最小二乘法的python的实现

偶然一次做项目的时候,需要对生长模型进行建模,这就意味着需要对其中的参数进行求解,其中最常见的也就是最小二乘法了,本文将其模型建立的过程的代码展示如下

第一步:导入关键的库


import pandas as pd
import numpy as np
from scipy.optimize import curve_fit
import matplotlib.pyplot as plt

建立一个逻辑生长模型的类

1,首先初始化定义的类

    def __init__(self, cases):
        '''
        Initializes Object

        Args: 1-D array of cases at each time step
        '''
        self.parameters = np.random.exponential(size = 3)
        print("Parameter Initialization")
        print(self.parameters)
        self.x = np.array([i for i in range(len(cases))])
        self.y = np.array(cases)

2,定义需要拟合的方程

def logistic(self, t, a, b, c):
        '''
        Logistic function for training

        Args: The time of the logistic prediction and the parameters
        Returns: Output of logistic function
        '''
        return c / (1 + a * np.exp(-b*t))

3,curve_fit最小二乘法进行参数的求解

def trainLogistic(self):
        '''
        Trains logistic growth model
        '''
        bounds = (0, [1e10, 10, 1e10])
        self.parameters, covariance = curve_fit(self.logistic, self.x, self.y, bounds=bounds, p0=self.parameters)

4,定义预测的函数,将求解后的方程带入进行求解

def predict(self, t):
        '''
        Logistic function for graphing and predictions

        Args: The time of the logistic prediction
        Returns: Output of logistic function
        '''
        return self.parameters[2] / (1 + self.parameters[0] * np.exp(-self.parameters[1]*t))

5,绘制求解的结果

def graph(self):
        '''
        Graphs the data with logistic model
        '''
        plt.scatter(self.x, self.y)
        predictArr = np.vectorize(self.predict)
        graphX = np.append(self.x, [i for i in range(len(self.x), len(self.x)*2)])
        numOfDays = 0
        for x in range(len(graphX[len(self.x):])):
            if self.predict(graphX[len(self.x) + x]) > self.parameters[2]*(0.999):
                numOfDays = x + 1
                break
        if (numOfDays == 0):
            numOfDays = "More Than " + str(len(self.x))
        plt.plot(graphX, predictArr(graphX))
        plt.title('Logistic Model Predictions | Max at ' + str("%.1f" % self.parameters[2]) + "\nReached in " + str(numOfDays) + " Years")
        plt.ylabel("Number")
        plt.xlabel("Years")
        print("Final Parameters")
        print(self.parameters)
        plt.show()

全部的代码如下所示:


class LogisticModel():

    def __init__(self, cases):
        '''
        Initializes Object

        Args: 1-D array of cases at each time step
        '''
        self.parameters = np.random.exponential(size = 3)
        print("Parameter Initialization")
        print(self.parameters)
        self.x = np.array([i for i in range(len(cases))])
        self.y = np.array(cases)

    def logistic(self, t, a, b, c):
        '''
        Logistic function for training

        Args: The time of the logistic prediction and the parameters
        Returns: Output of logistic function
        '''
        return c / (1 + a * np.exp(-b*t))

    def trainLogistic(self):
        '''
        Trains logistic growth model
        '''
        bounds = (0, [1e10, 10, 1e10])
        self.parameters, covariance = curve_fit(self.logistic, self.x, self.y, bounds=bounds, p0=self.parameters)

    def predict(self, t):
        '''
        Logistic function for graphing and predictions

        Args: The time of the logistic prediction
        Returns: Output of logistic function
        '''
        return self.parameters[2] / (1 + self.parameters[0] * np.exp(-self.parameters[1]*t))

    def graph(self):
        '''
        Graphs the data with logistic model
        '''
        plt.scatter(self.x, self.y)
        predictArr = np.vectorize(self.predict)
        graphX = np.append(self.x, [i for i in range(len(self.x), len(self.x)*2)])
        numOfDays = 0
        for x in range(len(graphX[len(self.x):])):
            if self.predict(graphX[len(self.x) + x]) > self.parameters[2]*(0.999):
                numOfDays = x + 1
                break
        if (numOfDays == 0):
            numOfDays = "More Than " + str(len(self.x))
        plt.plot(graphX, predictArr(graphX))
        plt.title('Logistic Model Predictions | Max at ' + str("%.1f" % self.parameters[2]) + "\nReached in " + str(numOfDays) + " Years")
        plt.ylabel("Number")
        plt.xlabel("Years")
        print("Final Parameters")
        print(self.parameters)
        plt.show()

  • 0
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值