Scipy中最小二乘函数leastsq()简单使用

本篇的主要内容:

  • 介绍Scipy中optimize模块的leastsq函数

最近接触到了Scipy中optimize模块的一些函数,optimize模块中提供了很多数值优化算法,其中,最小二乘法可以说是最经典的数值优化技术了, 通过最小化误差的平方来寻找最符合数据的曲线。在optimize模块中,使用leastsq()函数可以很快速地使用最小二乘法对数据进行拟合。

首先来看leastsq()函数地调用格式:

leastsq(func, 
        x0,
        args=(),
        Dfun=None,
        full_output=0,
        col_deriv=0,
        ftol=1.49012e-08,
        xtol=1.49012e-08,
        gtol=0.0,
        maxfev=0,
        epsfcn=0.0,
        factor=100,
        diag=None,
        warning=True)

参数还是非常多的,一般来说,我们只需要前三个参数就够了他们的作用分别是:

  • func:误差函数
  • x0:表示函数的参数
  • args()表示数据点

举个例子:
这里要进行拟合的数据点都分布在这条正弦曲线附近:

def func(x):
    return 2*np.sin(2*np.pi*x)

然后定义误差函数,所谓误差就是指我们拟合的曲线的值对应真实值的差:

def residuals(p, x, y):
    fun = np.poly1d(p)    # poly1d()函数可以按照输入的列表p返回一个多项式函数
    return y - fun(x)   # 返回真实值 与我们拟合的曲线上对应的值的差

这里设计了一个poly1d()函数,关于这个函数,简单理解下就是输入一个列表,返回以这个列表中的值为参数的多项式,例如:

输入:[1,2,3]
返回:x^2 + 2x + 3
多项式的次数是从0开始记的,要注意这个地方

下面定义关于拟合的曲线的函数:

# 拟合函数
def fitting(p):
    pars = np.random.rand(p+1)  # 生成p+1个随机数的列表,这样poly1d函数返回的多项式次数就是p
    r = leastsq(residuals, pars, args=(X, Y))   # 三个参数:误差函数、函数参数列表、数据点
    return r

注释里的内容就是要注意的地方,由于会多次调用拟合,多以写成了函数的形式,这里传入的p是一个数字,表示我们想要得到拟合曲线的次数,比如我想针对这些数据点得到一条3次的曲线,就调用p=3类似,注意这里leastsq()函数的返回值,这里的返回值保存的是拟合的曲线的信息,如果打印这里的r,就会发现返回了一个truple,其中第一维是一个列表,保存的是拟合的曲线的参数,所以要注意如何获得这些参数。
接下来定义一下我们要进行拟合的数据点,这里定义了10个:

# 要进行拟合的数据点
X = np.linspace(0, 1, 10)
Y = [np.random.normal(0, 0.1)+num for num in func(X)]  # 添加噪声

# 方便绘制曲线,所以创建多一些点
x_ = np.linspace(0, 1, 100)
y_ = func(x_)

调用拟合函数,并进行绘图:

fit_pars = fitting(3)[0]    # 注意返回值中的第一行才是拟合曲线的参数列表

plt.plot(x_, y_, label='real line')
plt.scatter(X, Y, label='real points')
plt.plot(x_, np.poly1d(fit_pars)(x_), label='fitting line')
plt.legend()
plt.show()

p=3的时候的图像:
在这里插入图片描述
当然,这里我直接传入p=3,也就是建立3次的曲线对数据点进行拟合,如果传入的p=1的时候,图像如下:
在这里插入图片描述
如果p=2,则是:
在这里插入图片描述
可以看到没有变化,也就是说没办法找到一条二次曲线,使得二次误差少于上面的一次曲线了。
完整代码如下:

import numpy as np
from scipy.optimize import leastsq
import matplotlib.pyplot as plt

# 数据点分布在这条曲线附近
def func(x):
    return 2*np.sin(2*np.pi*x)

# 误差函数, 计算拟合曲线与真实数据点之间的差 ,作为leastsq函数的输入
def residuals(p, x, y):
    fun = np.poly1d(p)    # poly1d()函数可以按照输入的列表p返回一个多项式函数
    return y - fun(x)

# 拟合函数
def fitting(p):
    pars = np.random.rand(p+1)  # 生成p+1个随机数的列表,这样poly1d函数返回的多项式次数就是p
    r = leastsq(residuals, pars, args=(X, Y))   # 三个参数:误差函数、函数参数列表、数据点
    return r

# 要进行拟合的数据点
X = np.linspace(0, 1, 10)
Y = [np.random.normal(0, 0.1)+num for num in func(X)]  # 添加噪声

# 方便绘制曲线,所以创建
x_ = np.linspace(0, 1, 100)
y_ = func(x_)

# print(fitting(3))   可以看一下返回的是什么
fit_pars = fitting(3)[0]

plt.plot(x_, y_, label='real line')
plt.scatter(X, Y, label='real points')
plt.plot(x_, np.poly1d(fit_pars)(x_), label='fitting line')
plt.legend()
plt.show()

以上~

  • 40
    点赞
  • 101
    收藏
    觉得还不错? 一键收藏
  • 8
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 8
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值