Python 调用 C 接口实现三次自然样条插值

C 语言接口需要的动态链接库源码:

三点三次自然样条插值 — Home v1.2023.11 文档 (illusionna.readthedocs.io)icon-default.png?t=N7T8https://illusionna.readthedocs.io/zh/latest/projects/Mathematics/Numerical%20Analysis/%E4%B8%89%E7%82%B9%E4%B8%89%E6%AC%A1%E8%87%AA%E7%84%B6%E6%A0%B7%E6%9D%A1%E6%8F%92%E5%80%BC/Spline.html#id2

Python 代码:

'''
# System --> Windows & Python3.8.0
# File ----> NaturalCubicSpline.py
# Author --> Illusionna
# Create --> 2024/2/22 21:08:56
'''
# -*- Encoding: UTF-8 -*-


import os
import ctypes   # 先编译生成动态链接库: gcc --share -o spline.dll NaturalCubicSpline.c
import numpy as np
from bisect import bisect_left
import matplotlib.pyplot as plt
from scipy.interpolate import CubicSpline   # 与官方标准库比对测试用.

def cls() -> None:
    os.system('cls')
    global root
    root = os.getcwd()
cls()


class NATURAL_CUBIC_SPLINE:
    def __init__(self, X:list, Y:list) -> None:
        self.__X = X
        self.__Y = Y
        self.__pos = False
        if len(X) == len(Y):
            lib = ctypes.CDLL(root + './spline.dll')
            self.__Spline = lib.__getattr__('Spline')
            self.__Spline.restype = ctypes.POINTER(ctypes.POINTER(ctypes.c_double))
            self.__Spline.argtypes = [
                ctypes.POINTER(ctypes.c_double),
                ctypes.POINTER(ctypes.c_double),
                ctypes.c_int
            ]
        else:
            assert print(f'\033[031m数组 X 长度: {len(X)}, 数组 Y 长度 {len(Y)}, 长度不一致.\033[0m')

    def Coefficients(self) -> list:
        ptr = ctypes.c_double * len(self.__X)
        a = self.__Spline(
            ptr(* self.__X),
            ptr(* self.__Y),
            len(self.__X)
        )
        self.__coefficients = [
            [a[i][j] for j in range(0, 5, 1)]
            for i in range(0, len(self.__X)-1, 1)
        ]
        self.__pos = True
        return self.__coefficients

    def Interpolate(self, x:float) -> 'function':
        def Parameters(period:int) -> tuple:
            xk = self.__coefficients[period][0]
            a = self.__coefficients[period][1]
            b = self.__coefficients[period][2]
            c = self.__coefficients[period][3]
            d = self.__coefficients[period][4]
            return (xk, a, b, c, d)

        def Calculate() -> 'function':
            idx = bisect_left(self.__X, x)
            n = len(self.__X)
            if (idx == 0) | (idx == 1):
                (xk, a, b, c, d) = Parameters(0)
            elif (idx == n) | (idx == n-1):
                (xk, a, b, c, d) = Parameters(-1)
            else:
                (xk, a, b, c, d) = Parameters(idx-1)
            y = lambda x: a + b*(x-xk) + c*(x-xk)**2 + d*(x-xk)**3
            return y(x)

        if self.__pos == True:
            return Calculate()
        else:
            self.Coefficients()
            return Calculate()

        
def Test1() -> None:
    X = [1, 2, 3, 7.23]
    Y = [2, 3, 5, -1.75]
    nodes:list = [1.25, 1.5, 1.75, 2.25, 2.5, 2.75, 4, 5, 6]    # 待预测节点横坐标向量.
    # ---------------------------------------------------------------------------------
    obj = NATURAL_CUBIC_SPLINE(
        X = X,
        Y = Y
    )
    # print(obj.Coefficients())     # 打印样条插值多项式系数.
    f = lambda x: obj.Interpolate(x)
    print(f'\033[036mIllusionna 三次自然样条插值\033[0m预测结果:\n{list(map(f, nodes))}')
    # ---------------------------------------------------------------------------------
    cs = CubicSpline(
        x = X,
        y = Y
    )
    print(f'\033[033mScipy 库三次样条插值\033[0m预测结果:\n{cs(nodes).tolist()}')

def Test2() -> None:
    X = np.linspace(0, 20, 100)
    Y = np.sin(X)
    nodes:list = [1, 2, 3, 4, 6, 7, 8, 9, 11, 13, 14, 16, 18, 19]    # 待预测节点横坐标向量.
    # ---------------------------------------------------------------------------------
    obj = NATURAL_CUBIC_SPLINE(
        X = X,
        Y = Y
    )
    f = lambda x: obj.Interpolate(x)
    print(f'\033[036mIllusionna 三次自然样条插值\033[0m预测结果:\n{list(map(f, nodes))}')
    # ---------------------------------------------------------------------------------
    cs = CubicSpline(
        x = X,
        y = Y
    )
    print(f'\033[033mScipy 库三次样条插值\033[0m预测结果:\n{cs(nodes).tolist()}')
    # ---------------------------------------------------------------------------------
    plt.scatter(X, Y, s=10)
    plt.plot(nodes, list(map(f, nodes)), 'r-')
    plt.plot(nodes, cs(nodes).tolist(), 'g--')
    plt.legend(['Observations', 'Illusionna predicts nodes', 'Scipy predicts nodes'])
    plt.show()


if __name__ == '__main__':
    Test1()
    # Test2()

  • 11
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值