KAN 学习 Day2 —— utils.py 与 spline.py 代码解读及测试

KAN学习Day1——模型框架解析及HelloKAN中,我对KAN模型的基本原理进行了简单说明,并将作者团队给出的入门教程hellokan跑了一遍,今天我们直接开始进行源码解读。

目录

一、kan目录

二、utils.py

2.1 导入库和模块

 2.2 逆函数定义

2.3 符号库定义

2.4 create_dataset

2.5 fit_params

2.6 sparse_mask

 2.7 add_symbolic

2.8 ex_round

2.9 augment_input

2.10 batch_jacobian

2.11 batch_hessian

2.12 create_dataset_from_data

2.13 model2param

2.14 get_derivative

三、spline.py

3.1 B_batch

3.2 coef2curve

3.3 curve2coef

3.4 extend_grid

四、总结 


一、kan目录

kan目录结构如下,包括了模型源码、检查点、实验以及assets等

e12295be65d94b3381e647242dc51eba.pngcc0f7d4a5a5148c995fcd44ad9bbbab6.png

 先了解一下这些文件/文件夹的大致信息:

  • kan\__init__.py:用于初始化Python包,方便使用时导入模块
  • kan\compiler.py:用于编译模型
  • kan\experiment.py:实验代码
  • kan\feynman.py:费曼函数,根据传入“name”的值确定函数,暂时没找到这个在哪里用到
  • kan\hypothesis.py:将函数进行线性分离,还包含一些画图函数
  • kan\KANLayer.py:KAN层的实现,看着就像是核心
  • kan\LBFGS.py:这个文件名似乎昨天见过,训练时的opt参数。L-BFGS是一种用于无约束优化问题的算法,它是一种拟牛顿方法,特别适用于大型稀疏问题。
  • kan\MLP.py:作者自己实现了一个MLP,应该使来与KAN做对比的
  • kan\MultKAN.py:在KANLayer的基础上实现的KAN类的定义,提供了关于构建和配置这种网络的详细信息。
  • kan\spline.py:样条函数的实现
  • kan\Symbolic_KANLayer.py:符号化的KAN层(?),暂时不知道这个是干啥的
  • kan\utils.py:通用模块
  • kan\.ipynb_checkpoints:看目录名,这个文件夹下存放的应该是检查点文件,但是似乎和模型的实现代码区别不大,没遇到过,还不知道有什么用。
  • kan\assets:这个目录下存放了两张图片,一张加号一张乘号,应该是对函数进行线性分离后,可视化时用的
  • kan\experiments:这个目录下是experiment1.ipynb,和昨天跑的hellokan差不多,今天再跑一下

二、utils.py

2.1 导入库和模块

import numpy as np
import torch
from sklearn.linear_model import LinearRegression
import sympy
import yaml
from sympy.utilities.lambdify import lambdify
import re
  1. numpy: NumPy 是一个强大的 Python 库,用于进行数值计算。它提供了大量的数学函数,用于数组操作、矩阵计算等。
  2. torch: PyTorch 是一个开源的机器学习库,由 Facebook 的 AI 研究团队开发。它提供了强大的 GPU 加速功能,用于深度学习模型的训练和推理。
  3. sklearn.linear_model.LinearRegression: scikit-learn 是一个开源机器学习库,它提供了许多机器学习算法的实现,包括线性回归。LinearRegression 类用于实现线性回归模型。
  4. sympy: SymPy 是一个用于符号数学的 Python 库。它允许用户进行代数计算、微积分、离散数学等。
  5. yaml: PyYAML 是一个用于解析和生成 YAML 数据的 Python 库。YAML 是一种类似于 JSON 的数据序列化格式,常用于配置文件。
  6. sympy.utilities.lambdifylambdify 是 SymPy 库中的一个函数,它可以将符号表达式转换为可调用的 Python 函数,这些函数可以接受 NumPy 或 PyTorch 数组作为输入。
  7. rere 是 Python 的正则表达式库,用于处理字符串。它提供了强大的模式匹配、搜索和替换功能。

 2.2 逆函数定义

f_inv = lambda x, y_th: ((x_th := 1/y_th), y_th/x_th*x * (torch.abs(x) < x_th) + torch.nan_to_num(1/x) * (torch.abs(x) >= x_th))
f_inv2 = lambda x, y_th: ((x_th := 1/y_th**(1/2)), y_th * (torch.abs(x) < x_th) + torch.nan_to_num(1/x**2) * (torch.abs(x) >= x_th))
f_inv3 = lambda x, y_th: ((x_th := 1/y_th**(1/3)), y_th/x_th*x * (torch.abs(x) < x_th) + torch.nan_to_num(1/x**3) * (torch.abs(x) >= x_th))
f_inv4 = lambda x, y_th: ((x_th := 1/y_th**(1/4)), y_th * (torch.abs(x) < x_th) + torch.nan_to_num(1/x**4) * (torch.abs(x) >= x_th))
f_inv5 = lambda x, y_th: ((x_th := 1/y_th**(1/5)), y_th/x_th*x * (torch.abs(x) < x_th) + torch.nan_to_num(1/x**5) * (torch.abs(x) >= x_th))
f_sqrt = lambda x, y_th: ((x_th := 1/y_th**2), x_th/y_th*x * (torch.abs(x) < x_th) + torch.nan_to_num(torch.sqrt(torch.abs(x))*torch.sign(x)) * (torch.abs(x) >= x_th))
f_power1d5 = lambda x, y_th: torch.abs(x)**1.5
f_invsqrt = lambda x, y_th: ((x_th := 1/y_th**2), y_th * (torch.abs(x) < x_th) + torch.nan_to_num(1/torch.sqrt(torch.abs(x))) * (torch.abs(x) >= x_th))
f_log = lambda x, y_th: ((x_th := torch.e**(-y_th)), - y_th * (torch.abs(x) < x_th) + torch.nan_to_num(torch.log(torch.abs(x))) * (torch.abs(x) >= x_th))
f_tan = lambda x, y_th: ((clip := x % torch.pi), (delta := torch.pi/2-torch.arctan(y_th)), - y_th/delta * (clip - torch.pi/2) * (torch.abs(clip - torch.pi/2) < delta) + torch.nan_to_num(torch.tan(clip)) * (torch.abs(clip - torch.pi/2) >= delta))
f_arctanh = lambda x, y_th: ((delta := 1-torch.tanh(y_th) + 1e-4), y_th * torch.sign(x) * (torch.abs(x) > 1 - delta) + torch.nan_to_num(torch.arctanh(x)) * (torch.abs(x) <= 1 - delta))
f_arcsin = lambda x, y_th: ((), torch.pi/2 * torch.sign(x) * (torch.abs(x) > 1) + torch.nan_to_num(torch.arcsin(x)) * (torch.abs(x) <= 1))
f_arccos = lambda x, y_th: ((), torch.pi/2 * (1-torch.sign(x)) * (torch.abs(x) > 1) + torch.nan_to_num(torch.arccos(x)) * (torch.abs(x) <= 1))
f_exp = lambda x, y_th: ((x_th := torch.log(y_th)), y_th * (x > x_th) + torch.exp(x) * (x <= x_th))

这些函数用于计算不同的数学函数的逆函数或相关操作。这些函数使用了PyTorch库,用于执行向量化操作和处理张量。

对于 f_inv

  • 函数参数说明:
    • x: 输入变量,通常是一个 torch 张量。
    • y_th: 阈值变量,用于确定 x 的处理方式。
  • 函数内部逻辑:

    • 计算 x_thx_th 被设定为 1/y_th
    • 条件处理
      • 如果 torch.abs(x) < x_th,则使用 y_th/x_th*x 进行计算。
      • 如果 torch.abs(x) >= x_th,则使用 torch.nan_to_num(1/x) 来处理,torch.nan_to_num 函数会将 NaN(非数字)值替换为0。

这些函数中的 x_th 是一个阈值,用于决定在哪个点开始使用近似方法。torch.nan_to_num 用于处理除以零的情况,返回 nan(不是数字)值。torch.abs 计算绝对值,torch.sign 返回符号,torch.log 和 torch.exp 分别计算自然对数和指数。torch.pi 是一个常量,表示π的值。torch.tanh 和 torch.arctanh 分别计算双曲正切和双曲反正切函数。

 这些函数用于反向传播计算场景。

测试:

from kan.utils import *

# 测试输入
x_test = torch.tensor([0.1, 1.0, 10.0])
y_th_test = torch.tensor([1.0, 1.0, 1.0])

# 测试所有函数
functions = [f_inv, f_inv2, f_inv3, f_inv4, f_inv5, f_sqrt, f_power1d5, f_invsqrt, f_log, f_tan, f_arctanh, f_arcsin, f_arccos, f_exp]
function_names = ["f_inv", "f_inv2", "f_inv3", "f_inv4", "f_inv5", "f_sqrt", "f_power1d5", "f_invsqrt", "f_log", "f_tan", "f_arctanh", "f_arcsin", "f_arccos", "f_exp"]

for name, func in zip(function_names, functions):
    result = func(x_test, y_th_test)
    print(f"{name} result: {result}")

f_inv result: (tensor([1., 1., 1.]), tensor([0.1000, 1.0000, 0.1000]))
f_inv2 result: (tensor([1., 1., 1.]), tensor([1.0000, 1.0000, 0.0100]))
f_inv3 result: (tensor([1., 1., 1.]), tensor([0.1000, 1.0000, 0.0010]))
f_inv4 result: (tensor([1., 1., 1.]), tensor([1.0000e+00, 1.0000e+00, 1.0000e-04]))
f_inv5 result: (tensor([1., 1., 1.]), tensor([1.0000e-01, 1.0000e+00, 1.0000e-05]))
f_sqrt result: (tensor([1., 1., 1.]), tensor([0.1000, 1.0000, 3.1623]))
f_power1d5 result: tensor([ 0.0316,  1.0000, 31.6228])
f_invsqrt result: (tensor([1., 1., 1.]), tensor([1.0000, 1.0000, 0.3162]))
f_log result: (tensor([0.3679, 0.3679, 0.3679]), tensor([-1.0000,  0.0000,  2.3026]))
f_tan result: (tensor([0.1000, 1.0000, 0.5752]), tensor([0.7854, 0.7854, 0.7854]), tensor([0.1003, 0.7268, 0.6484]))
f_arctanh result: (tensor([0.2385, 0.2385, 0.2385]), tensor([0.1003, 1.0000, 1.0000]))
f_arcsin result: ((), tensor([0.1002, 1.5708, 1.5708]))
f_arccos result: ((), tensor([1.4706, 0.0000, 0.0000]))
f_exp result: (tensor([0., 0., 0.]), tensor([1., 1., 1.]))

2.3 符号库定义

SYMBOLIC_LIB = {'x': (lambda x: x, lambda x: x, 1, lambda x, y_th: ((), x)),
                 'x^2': (lambda x: x**2, lambda x: x**2, 2, lambda x, y_th: ((), x**2)),
                 'x^3': (lambda x: x**3, lambda x: x**3, 3, lambda x, y_th: ((), x**3)),
                 'x^4': (lambda x: x**4, lambda x: x**4, 3, lambda x, y_th: ((), x**4)),
                 'x^5': (lambda x: x**5, lambda x: x**5, 3, lambda x, y_th: ((), x**5)),
                 '1/x': (lambda x: 1/x, lambda x: 1/x, 2, f_inv),
                 '1/x^2': (lambda x: 1/x**2, lambda x: 1/x**2, 2, f_inv2),
                 '1/x^3': (lambda x: 1/x**3, lambda x: 1/x**3, 3, f_inv3),
                 '1/x^4': (lambda x: 1/x**4, lambda x: 1/x**4, 4, f_inv4),
                 '1/x^5': (lambda x: 1/x**5, lambda x: 1/x**5, 5, f_inv5),
                 'sqrt': (lambda x: torch.sqrt(x), lambda x: sympy.sqrt(x), 2, f_sqrt),
                 'x^0.5': (lambda x: torch.sqrt(x), lambda x: sympy.sqrt(x), 2, f_sqrt),
                 'x^1.5': (lambda x: torch.sqrt(x)**3, lambda x: sympy.sqrt(x)**3, 4, f_power1d5),
                 '1/sqrt(x)': (lambda x: 1/torch.sqrt(x), lambda x: 1/sympy.sqrt(x), 2, f_invsqrt),
                 '1/x^0.5': (lambda x: 1/torch.sqrt(x), lambda x: 1/sympy.sqrt(x), 2, f_invsqrt),
                 'exp': (lambda x: torch.exp(x), lambda x: sympy.exp(x), 2, f_exp),
                 'log': (lambda x: torch.log(x), lambda x: sympy.log(x), 2, f_log),
                 'abs': (lambda x: torch.abs(x), lambda x: sympy.Abs(x), 3, lambda x, y_th: ((), torch.abs(x))),
                 'sin': (lambda x: torch.sin(x), lambda x: sympy.sin(x), 2, lambda x, y_th: ((), torch.sin(x))),
                 'cos': (lambda x: torch.cos(x), lambda x: sympy.cos(x), 2, lambda x, y_th: ((), torch.cos(x))),
                 'tan': (lambda x: torch.tan(x), lambda x: sympy.tan(x), 3, f_tan),
                 'tanh': (lambda x: torch.tanh(x), lambda x: sympy.tanh(x), 3, lambda x, y_th: ((), torch.tanh(x))),
                 'sgn': (lambda x: torch.sign(x), lambda x: sympy.sign(x), 3, lambda x, y_th: ((), torch.sign(x))),
                 'arcsin': (lambda x: torch.arcsin(x), lambda x: sympy.asin(x), 4, f_arcsin),
                 'arccos': (lambda x: torch.arccos(x), lambda x: sympy.acos(x), 4, f_arccos),
                 'arctan': (lambda x: torch.arctan(x), lambda x: sympy.atan(x), 4, lambda x, y_th: ((), torch.arctan(x))),
                 'arctanh': (lambda x: torch.arctanh(x), lambda x: sympy.atanh(x), 4, f_arctanh),
                 '0': (lambda x: x*0, lambda x: x*0, 0, lambda x, y_th: ((), x*0)),
                 'gaussian': (lambda x: torch.exp(-x**2), lambda x: sympy.exp(-x**2), 3, lambda x, y_th: ((), torch.exp(-x**2))),
                 #'cosh': (lambda x: torch.cosh(x), lambda x: sympy.cosh(x), 5),
                 #'sigmoid': (lambda x: torch.sigmoid(x), sympy.Function('sigmoid'), 4),
                 #'relu': (lambda x: torch.relu(x), relu),
}

字典SYMBOLIC_LIB是一个符号库,用于将数学表达式映射到相应的Python函数。这个库包含了各种数学函数和操作,以及它们的符号表示和相应的操作。

 

  • 键:
    • 'x': 字符串类型的键。
  • 值:
    • 第一个 lambda 函数lambda x: x,这个函数接受一个参数 x 并返回它,即它是一个恒等函数。
    • 第二个 lambda 函数:同样是 lambda x: x,功能和第一个函数相同。
    • 整数1,这是一个简单的整数值。
    • 第三个 lambda 函数lambda x, y_th: ((), x),这个函数接受两个参数 x 和 y_th,返回一个元组,其中第一个元素是空元组 (),第二个元素是参数 x

这个符号库可以用于自动微分、数值计算或者符号计算中,将数学表达式转换为可执行的Python函数,并提供了相应的导数信息。

测试:

def test_function(func_name, x_value):
    # 获取SYMBOLIC_LIB中对应的函数信息
    func_info = SYMBOLIC_LIB[func_name]
    
    # 使用torch计算函数值
    torch_func = func_info[0]
    torch_result = torch_func(x_value)
    
    # 使用sympy计算函数值
    sympy_func = func_info[1]
    sympy_result = sympy_func(x_value)
    
    # 打印结果
    print(f"Testing {func_name} with x = {x_value}")
    print(f"  torch result: {torch_result}")
    print(f"  sympy result: {sympy_result}")
    print(f"  Result is {'equal' if torch_result == sympy_result else 'not equal'} to sympy result")
    print()

# 测试不同的函数
test_values = torch.tensor([2])  # 选择一些测试值
for func_name in SYMBOLIC_LIB:
    for x_value in test_values:
        test_function(func_name, x_value)

Testing x with x = 2
  torch result: 2
  sympy result: 2
  Result is equal to sympy result

Testing x^2 with x = 2
  torch result: 4
  sympy result: 4
  Result is equal to sympy result

Testing x^3 with x = 2
  torch result: 8
  sympy result: 8
  Result is equal to sympy result

Testing x^4 with x = 2
  torch result: 16
  sympy result: 16
  Result is equal to sympy result

Testing x^5 with x = 2
  torch result: 32
  sympy result: 32
  Result is equal to sympy result

Testing 1/x with x = 2
  torch result: 0.5
  sympy result: 0.5
  Result is equal to sympy result

Testing 1/x^2 with x = 2
  torch result: 0.25
  sympy result: 0.25
  Result is equal to sympy result

Testing 1/x^3 with x = 2
  torch result: 0.125
  sympy result: 0.125
  Result is equal to sympy result

Testing 1/x^4 with x = 2
  torch result: 0.0625
  sympy result: 0.0625
  Result is equal to sympy result

Testing 1/x^5 with x = 2
  torch result: 0.03125
  sympy result: 0.03125
  Result is equal to sympy result

Testing sqrt with x = 2
  torch result: 1.4142135381698608
  sympy result: 1.41421356237310
  Result is not equal to sympy result

Testing x^0.5 with x = 2
  torch result: 1.4142135381698608
  sympy result: 1.41421356237310
  Result is not equal to sympy result

Testing x^1.5 with x = 2
  torch result: 2.8284268379211426
  sympy result: 2.82842712474619
  Result is not equal to sympy result

Testing 1/sqrt(x) with x = 2
  torch result: 0.7071067690849304
  sympy result: 0.707106781186547
  Result is not equal to sympy result

Testing 1/x^0.5 with x = 2
  torch result: 0.7071067690849304
  sympy result: 0.707106781186547
  Result is not equal to sympy result

Testing exp with x = 2
  torch result: 7.389056205749512
  sympy result: 7.38905609893065
  Result is not equal to sympy result

Testing log with x = 2
  torch result: 0.6931471824645996
  sympy result: 0.693147180559945
  Result is not equal to sympy result

Testing abs with x = 2
  torch result: 2
  sympy result: 2.00000000000000
  Result is equal to sympy result

Testing sin with x = 2
  torch result: 0.9092974066734314
  sympy result: 0.909297426825682
  Result is not equal to sympy result

Testing cos with x = 2
  torch result: -0.416146844625473
  sympy result: -0.416146836547142
  Result is not equal to sympy result

Testing tan with x = 2
  torch result: -2.185039758682251
  sympy result: -2.18503986326152
  Result is not equal to sympy result

Testing tanh with x = 2
  torch result: 0.9640275835990906
  sympy result: 0.964027580075817
  Result is not equal to sympy result

Testing sgn with x = 2
  torch result: 1
  sympy result: 1
  Result is equal to sympy result

Testing arcsin with x = 2
  torch result: nan
  sympy result: 1.5707963267949 - 1.31695789692482*I
  Result is not equal to sympy result

Testing arccos with x = 2
  torch result: nan
  sympy result: 1.31695789692482*I
  Result is not equal to sympy result

Testing arctan with x = 2
  torch result: 1.1071487665176392
  sympy result: 1.10714871779409
  Result is not equal to sympy result

Testing arctanh with x = 2
  torch result: nan
  sympy result: 0.549306144334055 - 1.5707963267949*I
  Result is not equal to sympy result

Testing 0 with x = 2
  torch result: 0
  sympy result: 0
  Result is equal to sympy result

Testing gaussian with x = 2
  torch result: 0.018315639346837997
  sympy result: 0.0183156388887342
  Result is not equal to sympy result

2.4 create_dataset

def create_dataset(f, 
                   n_var=2, 
                   f_mode = 'col',
                   ranges = [-1,1],
                   train_num=1000, 
                   test_num=1000,
                   normalize_input=False,
                   normalize_label=False,
                   device='cpu',
                   seed=0):
    '''
    create dataset
    
    Args:
    -----
        f : function
            the symbolic formula used to create the synthetic dataset
        ranges : list or np.array; shape (2,) or (n_var, 2)
            the range of input variables. Default: [-1,1].
        train_num : int
            the number of training samples. Default: 1000.
        test_num : int
            the number of test samples. Default: 1000.
        normalize_input : bool
            If True, apply normalization to inputs. Default: False.
        normalize_label : bool
            If True, apply normalization to labels. Default: False.
        device : str
            device. Default: 'cpu'.
        seed : int
            random seed. Default: 0.
        
    Returns:
    --------
        dataset : dic
            Train/test inputs/labels are dataset['train_input'], dataset['train_label'],
                        dataset['test_input'], dataset['test_label']
         
    Example
    -------
    >>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
    >>> dataset = create_dataset(f, n_var=2, train_num=100)
    >>> dataset['train_input'].shape
    torch.Size([100, 2])
    '''

    np.random.seed(seed)
    torch.manual_seed(seed)

    if len(np.array(ranges).shape) == 1:
        ranges = np.array(ranges * n_var).reshape(n_var,2)
    else:
        ranges = np.array(ranges)
        
    
    train_input = torch.zeros(train_num, n_var)
    test_input = torch.zeros(test_num, n_var)
    for i in range(n_var):
        train_input[:,i] = torch.rand(train_num,)*(ranges[i,1]-ranges[i,0])+ranges[i,0]
        test_input[:,i] = torch.rand(test_num,)*(ranges[i,1]-ranges[i,0])+ranges[i,0]
                
    if f_mode == 'col':
        train_label = f(train_input)
        test_label = f(test_input)
    elif f_mode == 'row':
        train_label = f(train_input.T)
        test_label = f(test_input.T)
    else:
        print(f'f_mode {f_mode} not recognized')
        
    # if has only 1 dimension
    if len(train_label.shape) == 1:
        train_label = train_label.unsqueeze(dim=1)
        test_label = test_label.unsqueeze(dim=1)
        
    def normalize(data, mean, std):
            return (data-mean)/std
            
    if normalize_input == True:
        mean_input = torch.mean(train_input, dim=0, keepdim=True)
        std_input = torch.std(train_input, dim=0, keepdim=True)
        train_input = normalize(train_input, mean_input, std_input)
        test_input = normalize(test_input, mean_input, std_input)
        
    if normalize_label == True:
        mean_label = torch.mean(train_label, dim=0, keepdim=True)
        std_label = torch.std(train_label, dim=0, keepdim=True)
        train_label = normalize(train_label, mean_label, std_label)
        test_label = normalize(test_label, mean_label, std_label)

    dataset = {}
    dataset['train_input'] = train_input.to(device)
    dataset['test_input'] = test_input.to(device)

    dataset['train_label'] = train_label.to(device)
    dataset['test_label'] = test_label.to(device)

    return dataset

函数参数说明:

  • f: 一个函数,代表了生成数据集的数学模型。这个函数接受一个或多个输入变量并返回一个或多个输出值。
  • n_var: 数据集中的变量数量。默认值为2。
  • f_mode: 指定函数 f 的输入模式。如果为 'col',则输入为列向量;如果为 'row',则输入为行向量。默认值为 'col'
  • ranges: 输入变量的范围。默认为 [-1, 1]。如果是一个长度为2的数组,则假设所有变量都有相同的范围;如果是一个长度为 n_var 的数组,则每个变量有各自的范围。
  • train_num: 训练数据集的样本数量。默认值为1000。
  • test_num: 测试数据集的样本数量。默认值为1000。
  • normalize_input: 如果为 True,则对输入数据进行归一化处理。默认值为 False
  • normalize_label: 如果为 True,则对输出标签进行归一化处理。默认值为 False
  • device: 数据存储的设备,如 'cpu' 或 'cuda'。默认值为 'cpu'
  • seed: 随机数生成器的种子,用于确保每次运行时生成相同的数据集。默认值为0。

函数实现:

  1. 随机生成输入数据:使用 torch.rand 函数生成在指定范围内的随机数,然后将这些随机数转换为训练和测试数据集的输入。
  2. 计算输出数据:根据函数 f 计算训练和测试数据集的输出。
  3. 数据归一化:如果 normalize_input 或 normalize_label 为 True,则对输入数据和输出数据进行归一化处理。
  4. 数据存储:将处理后的训练和测试数据集存储在字典 dataset 中,并将数据移动到指定的设备。

测试:

# 定义一个简单的函数,用于创建数据集
def f(x):
    return torch.sin(x[:, 0]) * torch.cos(x[:, 1]) + x[:, 1]**2

# 调用 create_dataset 函数
def test_create_dataset():
    dataset = create_dataset(
        f=f,
        n_var=2,
        train_num=100,
        test_num=100,
        normalize_input=False,
        normalize_label=False,
        device='cpu',
        seed=0
    )

    # 检查数据集的结构
    assert 'train_input' in dataset, "Dataset does not contain 'train_input'"
    assert 'train_label' in dataset, "Dataset does not contain 'train_label'"
    assert 'test_input' in dataset, "Dataset does not contain 'test_input'"
    assert 'test_label' in dataset, "Dataset does not contain 'test_label'"

    # 检查输入和标签的形状
    assert dataset['train_input'].shape == torch.Size([100, 2]), "Train input shape is incorrect"
    assert dataset['train_label'].shape == torch.Size([100, 1]), "Train label shape is incorrect"
    assert dataset['test_input'].shape == torch.Size([100, 2]), "Test input shape is incorrect"
    assert dataset['test_label'].shape == torch.Size([100, 1]), "Test label shape is incorrect"

    print("All tests passed.")

# 运行测试脚本
test_create_dataset()

 All tests passed.

2.5 fit_params

def fit_params(x, y, fun, a_range=(-10,10), b_range=(-10,10), grid_number=101, iteration=3, verbose=True, device='cpu'):
    '''
    fit a, b, c, d such that
    
    .. math::
        |y-(cf(ax+b)+d)|^2
        
    is minimized. Both x and y are 1D array. Sweep a and b, find the best fitted model.
    
    Args:
    -----
        x : 1D array
            x values
        y : 1D array
            y values
        fun : function
            symbolic function
        a_range : tuple
            sweeping range of a
        b_range : tuple
            sweeping range of b
        grid_num : int
            number of steps along a and b
        iteration : int
            number of zooming in
        verbose : bool
            print extra information if True
        device : str
            device
        
    Returns:
    --------
        a_best : float
            best fitted a
        b_best : float
            best fitted b
        c_best : float
            best fitted c
        d_best : float
            best fitted d
        r2_best : float
            best r2 (coefficient of determination)
    
    Example
    -------
    >>> num = 100
    >>> x = torch.linspace(-1,1,steps=num)
    >>> noises = torch.normal(0,1,(num,)) * 0.02
    >>> y = 5.0*torch.sin(3.0*x + 2.0) + 0.7 + noises
    >>> fit_params(x, y, torch.sin)
    r2 is 0.9999727010726929
    (tensor([2.9982, 1.9996, 5.0053, 0.7011]), tensor(1.0000))
    '''
    # fit a, b, c, d such that y=c*fun(a*x+b)+d; both x and y are 1D array.
    # sweep a and b, choose the best fitted model   
    for _ in range(iteration):
        a_ = torch.linspace(a_range[0], a_range[1], steps=grid_number, device=device)
        b_ = torch.linspace(b_range[0], b_range[1], steps=grid_number, device=device)
        a_grid, b_grid = torch.meshgrid(a_, b_, indexing='ij')
        post_fun = fun(a_grid[None,:,:] * x[:,None,None] + b_grid[None,:,:])
        x_mean = torch.mean(post_fun, dim=[0], keepdim=True)
        y_mean = torch.mean(y, dim=[0], keepdim=True)
        numerator = torch.sum((post_fun - x_mean)*(y-y_mean)[:,None,None], dim=0)**2
        denominator = torch.sum((post_fun - x_mean)**2, dim=0)*torch.sum((y - y_mean)[:,None,None]**2, dim=0)
        r2 = numerator/(denominator+1e-4)
        r2 = torch.nan_to_num(r2)
        
        
        best_id = torch.argmax(r2)
        a_id, b_id = torch.div(best_id, grid_number, rounding_mode='floor'), best_id % grid_number
        
        
        if a_id == 0 or a_id == grid_number - 1 or b_id == 0 or b_id == grid_number - 1:
            if _ == 0 and verbose==True:
                print('Best value at boundary.')
            if a_id == 0:
                a_range = [a_[0], a_[1]]
            if a_id == grid_number - 1:
                a_range = [a_[-2], a_[-1]]
            if b_id == 0:
                b_range = [b_[0], b_[1]]
            if b_id == grid_number - 1:
                b_range = [b_[-2], b_[-1]]
            
        else:
            a_range = [a_[a_id-1], a_[a_id+1]]
            b_range = [b_[b_id-1], b_[b_id+1]]
            
    a_best = a_[a_id]
    b_best = b_[b_id]
    post_fun = fun(a_best * x + b_best)
    r2_best = r2[a_id, b_id]
    
    if verbose == True:
        print(f"r2 is {r2_best}")
        if r2_best < 0.9:
            print(f'r2 is not very high, please double check if you are choosing the correct symbolic function.')

    post_fun = torch.nan_to_num(post_fun)
    reg = LinearRegression().fit(post_fun[:,None].detach().cpu().numpy(), y.detach().cpu().numpy())
    c_best = torch.from_numpy(reg.coef_)[0].to(device)
    d_best = torch.from_numpy(np.array(reg.intercept_)).to(device)
    return torch.stack([a_best, b_best, c_best, d_best]), r2_best

这段代码定义了一个名为 fit_params 的函数,用于拟合一组数据点到一个给定的符号函数 fun。目标是找到一组参数 abcd,使得 y 与 c*fun(a*x+b)+d 的差的平方和最小。

函数参数:

  • x 和 y:分别为输入和输出数据的 1D 数组。
  • fun:一个符号函数,用于拟合数据。
  • a_range 和 b_range:用于参数 a 和 b 的搜索范围。
  • grid_number:在 a_range 和 b_range 上划分的网格点数。
  • iteration:迭代次数,用于逐步缩小参数搜索范围。
  • verbose:如果为 True,则在每次迭代后输出额外信息。
  • device:执行计算的设备(如 'cpu' 或 'cuda')。

函数逻辑:

  1. 初始化参数搜索范围:首先,定义了参数 a 和 b 的搜索范围,并使用 torch.linspace 生成网格点。
  2. 迭代优化:通过 iteration 次迭代,每次迭代中,函数会计算所有参数组合下的残差平方和(r2),并找到最小的 r2 值对应的参数组合。
    1. 在每次迭代中,计算 fun 函数在当前参数组合下的输出,然后计算输出与 y 的平均值,以及残差平方和。
    2. 通过比较残差平方和,找到最佳参数组合。
  3. 边界处理:如果最佳参数组合位于搜索范围的边界上,函数会调整搜索范围以进行下一次迭代。
  4. 最终拟合:在迭代结束后,使用找到的最佳参数组合 a 和 b 来计算 c 和 d,并返回最佳参数组合和 r2 值。
  5. 回归分析:最后,使用线性回归来进一步拟合数据,计算 c 和 d 的值。

测试:

# 定义测试函数
def test_fit_params():
    num = 100
    x = torch.linspace(-1, 1, steps=num)
    noises = torch.normal(0, 1, (num,)) * 0.02
    y = 5.0 * torch.sin(3.0 * x + 2.0) + 0.7 + noises

    def sin_fun(x):
        return torch.sin(x)

    result, r2_best = fit_params(x, y, sin_fun)
    print(result)
    print(r2_best)

# 运行测试
test_fit_params()

r2 is 0.9999620914459229

tensor([-2.9993, -2.0006, -5.0089, 0.7001])

tensor(1.0000)

 这网格寻参有点东西啊

2.6 sparse_mask

def sparse_mask(in_dim, out_dim):
    '''
    get sparse mask
    '''
    in_coord = torch.arange(in_dim) * 1/in_dim + 1/(2*in_dim)
    out_coord = torch.arange(out_dim) * 1/out_dim + 1/(2*out_dim)

    dist_mat = torch.abs(out_coord[:,None] - in_coord[None,:])
    in_nearest = torch.argmin(dist_mat, dim=0)
    in_connection = torch.stack([torch.arange(in_dim), in_nearest]).permute(1,0)
    out_nearest = torch.argmin(dist_mat, dim=1)
    out_connection = torch.stack([out_nearest, torch.arange(out_dim)]).permute(1,0)
    all_connection = torch.cat([in_connection, out_connection], dim=0)
    mask = torch.zeros(in_dim, out_dim)
    mask[all_connection[:,0], all_connection[:,1]] = 1.
    
    return mask

这个 sparse_mask 函数的目的是生成一个稀疏掩码矩阵,该矩阵用于表示输入维度(in_dim)和输出维度(out_dim)之间的稀疏连接关系。

  1. 坐标生成:
    1. in_coord = torch.arange(in_dim) * 1/in_dim + 1/(2*in_dim): 生成一个长度为 in_dim 的张量 in_coord,其中每个元素是其在维度上的归一化坐标。坐标范围从 1/(2*in_dim) 到 1 - 1/(2*in_dim)
    2. out_coord = torch.arange(out_dim) * 1/out_dim + 1/(2*out_dim): 同样地,生成一个长度为 out_dim 的张量 out_coord,表示输出维度的归一化坐标。
  2. 距离矩阵计算:
    1. dist_mat = torch.abs(out_coord[:,None] - in_coord[None,:]): 计算输出坐标和输入坐标之间的欧几里得距离的绝对值,形成一个 (out_dim, in_dim) 的距离矩阵 dist_mat
  3. 最近邻索引:
    1. in_nearest = torch.argmin(dist_mat, dim=0): 对于输出坐标的每一行,找到与输入坐标最接近的索引,并存储在 in_nearest 中。
    2. out_nearest = torch.argmin(dist_mat, dim=1): 对于输入坐标的每一列,找到与输出坐标最接近的索引,并存储在 out_nearest 中。
  4. 连接关系:
    1. in_connection = torch.stack([torch.arange(in_dim), in_nearest]).permute(1,0): 将输入维度的索引和对应的最近邻输出维度的索引组合成一个张量,并进行转置。
    2. out_connection = torch.stack([out_nearest, torch.arange(out_dim)]).permute(1,0): 类似地,组合并转置得到输出维度的连接关系张量。
    3. all_connection = torch.cat([in_connection, out_connection], dim=0): 将输入和输出的连接关系张量沿着第一个维度(行方向)拼接起来。
  5. 掩码矩阵:
    1. mask = torch.zeros(in_dim, out_dim): 初始化一个全零的 (in_dim, out_dim) 矩阵,用于存储稀疏掩码。
    2. mask[all_connection[:,0], all_connection[:,1]] = 1.: 根据连接关系张量 all_connection 中的索引,将掩码矩阵中对应的位置设置为 1。
  6. 返回结果:
    1. return mask: 返回填充了 1 的稀疏掩码矩阵。

这个函数可以用于构建稀疏连接的神经网络层,它可以帮助定义节点之间的连接关系。

测试:

# 测试函数
def test_sparse_mask():
    in_dim = 2
    out_dim = 5

    mask = sparse_mask(in_dim, out_dim)

    # 打印掩码矩阵
    print("Mask matrix:")
    print(mask)


# 运行测试
test_sparse_mask()

Mask matrix:
tensor([[1., 1., 1., 0., 0.],
            [0., 0., 0., 1., 1.]])

 2.7 add_symbolic

def add_symbolic(name, fun, c=1, fun_singularity=None):
    '''
    add a symbolic function to library
    
    Args:
    -----
        name : str
            name of the function
        fun : fun
            torch function or lambda function
    
    Returns:
    --------
        None
    
    Example
    -------
    >>> print(SYMBOLIC_LIB['Bessel'])
    KeyError: 'Bessel'
    >>> add_symbolic('Bessel', torch.special.bessel_j0)
    >>> print(SYMBOLIC_LIB['Bessel'])
    (<built-in function special_bessel_j0>, Bessel)
    '''
    exec(f"globals()['{name}'] = sympy.Function('{name}')")
    if fun_singularity==None:
        fun_singularity = fun
    SYMBOLIC_LIB[name] = (fun, globals()[name], c, fun_singularity)

定义了一个名为 add_symbolic 的函数,用于在 Python 的全局命名空间中添加一个符号函数,并将其存储在 SYMBOLIC_LIB 字典中。

参数说明:

  1. name:函数的名称。
  2. fun:一个函数对象,可以是 torch 中的函数或者一个使用 lambda 定义的函数。
  3. c:一个常数,默认值为 1。
  4. fun_singularity:一个可选参数,表示该函数的奇异点,如果未提供,则默认使用 fun

代码解释:

  1. 使用 exec 添加符号函数:通过 exec 语句,将 name 作为符号函数的名称添加到全局命名空间中。这会创建一个与 name 相关的符号函数,例如 Bessel
  2. 添加至 SYMBOLIC_LIB:如果 fun_singularity 未被提供,则它会默认等于 fun。然后将一个元组 (fun, 符号函数对象, c, fun_singularity) 添加到 SYMBOLIC_LIB 字典中,其中:
    1. fun 是原始函数对象。
    2. 符号函数对象 是通过 sympy.Function 创建的符号函数。
    3. c 是常数。
    4. fun_singularity 是函数的奇异点定义。

测试:

# 添加符号函数
add_symbolic('Bessel', torch.special.bessel_j0)

# 检查添加是否成功
print(SYMBOLIC_LIB['Bessel'])  # 应该输出 (torch.special.bessel_j0, Bessel)

# 添加另一个符号函数
add_symbolic('Sine', torch.sin)

# 检查添加是否成功
print(SYMBOLIC_LIB['Sine'])  # 应该输出 (torch.sin, Sine)

# 使用添加的符号函数创建表达式
x1 = torch.tensor(0.5)
x2 = sympy.Symbol('x')
expr = SYMBOLIC_LIB['Bessel'][0](x1) + SYMBOLIC_LIB['Sine'][1](x2)
print(expr)  # 应该输出一个包含 Bessel 函数和 sin 函数的表达式

# 进行符号计算(例如微分)
diff_expr = expr.diff(x2)
print(diff_expr)  # 应该输出表达式的导数

 (<built-in function special_bessel_j0>, Bessel, 1, <built-in function special_bessel_j0>)
(<built-in method sin of type object at 0x00007FFCFCAC6E90>, Sine, 1, <built-in method sin of type object at 0x00007FFCFCAC6E90>)
Sine(x) + 0.93846982717514
Derivative(Sine(x), x)

2.8 ex_round

def ex_round(ex1, n_digit):
    '''
    rounding the numbers in an expression to certain floating points
    
    Args:
    -----
        ex1 : sympy expression
        n_digit : int
        
    Returns:
    --------
        ex2 : sympy expression
    
    Example
    -------
    >>> from kan.utils import *
    >>> from sympy import *
    >>> input_vars = a, b = symbols('a b')
    >>> expression = 3.14534242 * exp(sin(pi*a) + b**2) - 2.32345402
    >>> ex_round(expression, 2)
    '''
    ex2 = ex1
    for a in sympy.preorder_traversal(ex1):
        if isinstance(a, sympy.Float):
            ex2 = ex2.subs(a, round(a, n_digit))
    return ex2

这段 Python 代码定义了一个名为 ex_round 的函数,用于将一个 sympy 表达式中的浮点数舍入到指定的小数位数。

函数的工作方式是通过遍历输入表达式中的每个元素。如果元素是一个 sympy.Float 类型(即浮点数),就使用 round 函数将其舍入到指定的小数位数,并使用 subs 方法替换原始的浮点数。

这样,最终返回的表达式 ex2 就是经过舍入处理后的结果。

测试:

from sympy import *
input_vars = a, b = symbols('a b')
expression = 3.14534242 * exp(sin(pi*a) + b**2) - 2.32345402
ex_round(expression, 2)

eq?3.15e%5E%7Bb%5E%7B2%7D&plus;%28%5Cpi%20a%29%7D-2.32 

2.9 augment_input

def augment_input(orig_vars, aux_vars, x):
    '''
    augment inputs
    
    Args:
    -----
        orig_vars : list of sympy symbols
        aux_vars : list of auxiliary symbols
        x : inputs
        
    Returns:
    --------
        augmented inputs
    
    Example
    -------
    >>> from kan.utils import *
    >>> from sympy import *
    >>> orig_vars = a, b = symbols('a b')
    >>> aux_vars = [a + b, a * b]
    >>> x = torch.rand(100, 2)
    >>> augment_input(orig_vars, aux_vars, x).shape
    '''
    # if x is a tensor
    if isinstance(x, torch.Tensor):
        
        aux_values = torch.tensor([]).to(x.device)

        for aux_var in aux_vars:
            func = lambdify(orig_vars, aux_var,'numpy') # returns a numpy-ready function
            aux_value = torch.from_numpy(func(*[x[:,[i]].numpy() for i in range(len(orig_vars))]))
            aux_values = torch.cat([aux_values, aux_value], dim=1)
            
        x = torch.cat([aux_values, x], dim=1)

    # if x is a dataset
    elif isinstance(x, dict):
        x['train_input'] = augment_input(orig_vars, aux_vars, x['train_input'])
        x['test_input'] = augment_input(orig_vars, aux_vars, x['test_input'])
        
    return x

这段代码定义了一个名为 augment_input 的函数,它的目的是将额外的辅助变量(aux_vars)添加到原始变量(orig_vars)的输入中。这个函数可以处理两种类型的输入:张量(tensor)和字典(通常表示数据集)。

参数

  • orig_vars:一个包含原始符号变量的列表。
  • aux_vars:一个包含辅助符号变量的列表。
  • x:输入数据,可以是张量或字典。

返回值

  • augmented inputs:添加了辅助变量后的输入数据。

代码解释

  1. 检查输入类型

    • 如果 x 是一个张量(torch.Tensor),则进行以下操作:

      • 初始化一个空的张量 aux_values 用于存储辅助变量的值。
      • 遍历 aux_vars 列表中的每个辅助变量。
      • 对于每个辅助变量,使用 lambdify 函数将辅助变量与原始变量关联起来,生成一个可以在 NumPy 环境中运行的函数。
      • 使用这个函数计算每个辅助变量的值,并将原始输入 x 转换为 NumPy 数组,以便与函数一起使用。
      • 将计算出的辅助变量的值转换为 PyTorch 张量,并与 aux_values 张量拼接,增加一个维度。
      • 最后,将原始输入 x 与新的辅助变量张量拼接,以增加额外的输入维度。
    • 如果 x 是一个字典,这通常表示一个数据集,则递归地对 x 中的 'train_input' 和 'test_input' 进行相同的处理。

  2. 返回值

    • 无论输入类型如何,函数最终都会返回处理后的输入数据 x

 测试:

示例中的注释部分提供了一个使用 augment_input 函数的例子,其中 orig_vars 包含两个符号 a 和 baux_vars 包含两个辅助变量 a + b 和 a * bx 是一个形状为 (100, 2) 的随机张量。调用 augment_input 函数后,输出的张量形状将会是 (100, 4),因为它增加了两个辅助变量的维度。

orig_vars = a, b = symbols('a b')
aux_vars = [a + b, a * b]
x = torch.rand(100, 2)
augment_input(orig_vars, aux_vars, x).shape

 torch.Size([100, 4])

2.10 batch_jacobian

def batch_jacobian(func, x, create_graph=False):
    '''
    jacobian
    
    Args:
    -----
        func : function or model
        x : inputs
        create_graph : bool
        
    Returns:
    --------
        jacobian
    
    Example
    -------
    >>> from kan.utils import batch_jacobian
    >>> x = torch.normal(0,1,size=(100,2))
    >>> model = lambda x: x[:,[0]] + x[:,[1]]
    >>> batch_jacobian(model, x)
    '''
    # x in shape (Batch, Length)
    def _func_sum(x):
        return func(x).sum(dim=0)
    return torch.autograd.functional.jacobian(_func_sum, x, create_graph=create_graph)[0]

这段代码定义了一个名为 batch_jacobian 的函数,其目的是计算给定函数 func 在输入 x 上的雅可比矩阵(Jacobian matrix)。雅可比矩阵是一个矩阵,其中每一行对应函数对输入变量的偏导数。

参数

  • func:一个函数或模型,它接受输入 x 并返回输出。
  • x:输入数据,其形状为 (Batch, Length)
  • create_graph:一个布尔值,指示是否为计算梯度创建计算图。如果设置为 True,则可以计算更高阶的导数。

返回值

  • jacobian:输入 x 上函数 func 的雅可比矩阵。

代码解释

  1. 内部函数 _func_sum

    • 定义了一个内部函数 _func_sum,它计算输入 x 通过 func 函数后的输出,并对所有元素进行求和。
  2. 计算雅可比矩阵

    • 使用 torch.autograd.functional.jacobian 函数计算 _func_sum 关于输入 x 的雅可比矩阵。
    • jacobian 函数接受 _func_sum 和 x 作为输入,并设置 create_graph 参数,以便根据需要创建计算图。
    • 最后,通过索引 [0] 获取雅可比矩阵的第一个元素,即输出矩阵。
  3. 返回雅可比矩阵

    • 函数返回计算出的雅可比矩阵。

测试:

x = torch.normal(0,10,size=(10,2))
print(x)
model = lambda x: 2*x[:,[0]]**2 + x[:,[1]]**3
batch_jacobian(model, x)

tensor([[ -3.2299,  -5.6995],
        [ -0.8403,   3.2892],
        [  1.5103, -11.5371],
        [ -9.0327,   1.7513],
        [  7.1268,  -4.4041],
        [-12.4855,  16.1884],
        [  5.5421, -14.9643],
        [  0.9730, -11.0945],
        [-12.5421,  -4.8494],
        [ -1.1888,   7.5661]])
 

tensor([[-12.9195,  97.4539],
        [ -3.3611,  32.4565],
        [  6.0412, 399.3152],
        [-36.1308,   9.2016],
        [ 28.5072,  58.1878],
        [-49.9420, 786.1920],
        [ 22.1682, 671.7947],
        [  3.8919, 369.2659],
        [-50.1685,  70.5492],
        [ -4.7551, 171.7395]])

2.11 batch_hessian

def batch_hessian(model, x, create_graph=False):
    '''
    hessian
    
    Args:
    -----
        func : function or model
        x : inputs
        create_graph : bool
        
    Returns:
    --------
        jacobian
    
    Example
    -------
    >>> from kan.utils import batch_hessian
    >>> x = torch.normal(0,1,size=(100,2))
    >>> model = lambda x: x[:,[0]]**2 + x[:,[1]]**2
    >>> batch_hessian(model, x)
    '''
    # x in shape (Batch, Length)
    jac = lambda x: batch_jacobian(model, x, create_graph=True)
    def _jac_sum(x):
        return jac(x).sum(dim=0)
    return torch.autograd.functional.jacobian(_jac_sum, x, create_graph=create_graph).permute(1,0,2)

这段代码定义了一个名为 batch_hessian 的函数,用于计算模型的批量海森矩阵(Hessian matrix)。海森矩阵是一个函数在某点的二阶偏导数矩阵,对于优化和理解模型行为非常有用。

参数解释

  • model: 这是一个函数或模型对象,它接受输入 x 并返回输出。
  • x: 输入数据,形状为 (Batch, Length),其中 Batch 是数据集的大小,Length 是输入向量的维度。
  • create_graph: 一个布尔值,如果为 True,则返回的计算图将包含额外的图节点,这在需要计算高阶导数时有用。

返回值

  • jacobian: 批量雅可比矩阵,形状为 (Batch, Length, Length),表示对每个输入样本的每个维度对所有输出维度的偏导数。

函数实现细节

  1. 计算雅可比矩阵:首先,定义一个内部函数 _jac_sum,它计算模型对输入 x 的雅可比矩阵,并对矩阵的每一列求和。这样做的目的是为了在计算海森矩阵时减少维度。

  2. 计算海森矩阵:使用 torch.autograd.functional.jacobian 函数计算 _jac_sum 对输入 x 的海森矩阵。这个函数返回的矩阵形状为 (Batch, Length, Length, Length),表示对每个输入样本的每个维度对所有输出维度的偏导数矩阵。

  3. 调整输出形状:最后,通过调用 .permute(1,0,2) 来调整输出矩阵的维度,使其形状为 (Batch, Length, Length),与通常期望的海森矩阵形状一致。

测试:

x = torch.normal(0,10,size=(10,2))
print(x)
model = lambda x: x[:,[0]]**2 + 2*x[:,[1]]**3
batch_hessian(model, x)

tensor([[  6.4984,  18.0226],
        [  3.3973,  -0.4544],
        [-17.2917,  -9.5806],
        [  8.0034, -26.3708],
        [  6.4028,   7.4857],
        [ -2.9699,  10.9121],
        [ -3.7851, -11.1280],
        [ 11.3358,  10.0328],
        [  3.9513,   4.9699],
        [ -1.2890,   0.8191]])
 

tensor([[[  38.9905,    0.0000],
         [   0.0000,  216.2708]],

        [[  20.3835,    0.0000],
         [   0.0000,   -5.4534]],

        [[-103.7502,    0.0000],
         [   0.0000, -114.9676]],

        [[  48.0205,    0.0000],
         [   0.0000, -316.4493]],

        [[  38.4165,    0.0000],
         [   0.0000,   89.8279]],

        [[ -17.8196,    0.0000],
         [   0.0000,  130.9453]],

        [[ -22.7108,    0.0000],
         [   0.0000, -133.5359]],

        [[  68.0151,    0.0000],
         [   0.0000,  120.3941]],

        [[  23.7078,    0.0000],
         [   0.0000,   59.6382]],

        [[  -7.7341,    0.0000],
         [   0.0000,    9.8288]]])

 雅可比矩阵:[-12.9195,  97.4539],,

海森矩阵:[[  38.9905,    0.0000], [   0.0000,  216.2708]],

形状存在略微的区别

2.12 create_dataset_from_data

def create_dataset_from_data(inputs, labels, train_ratio=0.8, device='cpu'):
    '''
    create dataset from data
    
    Args:
    -----
        inputs : 2D torch.float
        labels : 2D torch.float
        train_ratio : float
            the ratio of training fraction
        device : str
        
    Returns:
    --------
        dataset (dictionary)
    
    Example
    -------
    >>> from kan.utils import create_dataset_from_data
    >>> x = torch.normal(0,1,size=(100,2))
    >>> y = torch.normal(0,1,size=(100,1))
    >>> dataset = create_dataset_from_data(x, y)
    >>> dataset['train_input'].shape
    '''
    num = inputs.shape[0]
    train_id = np.random.choice(num, int(num*train_ratio), replace=False)
    test_id = list(set(np.arange(num)) - set(train_id))
    dataset = {}
    dataset['train_input'] = inputs[train_id].detach().to(device)
    dataset['test_input'] = inputs[test_id].detach().to(device)
    dataset['train_label'] = labels[train_id].detach().to(device)
    dataset['test_label'] = labels[test_id].detach().to(device)
    
    return dataset

create_dataset_from_data函数用于从给定的数据(输入和标签)创建训练集和测试集。

函数参数

  • inputs : torch.float 类型的 2D 张量,表示输入数据。
  • labels : torch.float 类型的 2D 张量,表示对应的标签数据。
  • train_ratio : float 类型,表示训练数据与总数据的比例(默认为 0.8)。
  • device : str 类型,表示数据应该放在哪个设备上(默认为 'cpu')。

函数返回

  • dataset : dict 类型,包含训练集输入、训练集标签、测试集输入和测试集标签。

函数实现

  1. 计算输入数据的总数量。
  2. 使用 np.random.choice 从数据中随机选择一定数量的样本作为训练集。
  3. 通过计算剩余样本作为测试集。
  4. 将训练集和测试集的输入和标签分别存储在字典中,并将它们移动到指定的设备上。

 测试:

x = torch.normal(0,1,size=(100,2))
y = torch.normal(0,1,size=(100,1))
dataset = create_dataset_from_data(x, y)
dataset['train_input'].shape

 torch.Size([80, 2])

2.13 model2param

def model2param(model):
    '''
    turn model parameters into a flattened vector
    '''
    p = torch.tensor([]).to(model.device)
    for params in model.parameters():
        p = torch.cat([p, params.reshape(-1,)], dim=0)
    return p

model2param函数的主要目的是将模型的所有参数合并成一个一维张量。这个过程通常在需要对模型参数进行批量操作(如梯度计算、参数初始化等)时使用。

函数参数

  • model : PyTorch模型对象。这个模型包含要转换为一维向量的所有参数。

函数实现

函数首先定义了一个空的torch.tensor对象p,并将model的设备属性(device)复制给p,确保p与模型参数位于同一设备上。

接下来,函数遍历模型的所有参数(params),将每个参数的形状调整为一维(params.reshape(-1,)),然后使用torch.cat函数将所有参数连接成一个一维张量。-1作为形状参数表示自动计算维度,使得所有参数向量可以被正确连接。

最后,函数返回合并后的参数向量p

2.14 get_derivative

def get_derivative(model, inputs, labels, derivative='hessian', loss_mode='pred', reg_metric='w', lamb=0., lamb_l1=1., lamb_entropy=0.):
    '''
    compute the jacobian/hessian of loss wrt to model parameters
    
    Args:
    -----
        inputs : 2D torch.float
        labels : 2D torch.float
        derivative : str
            'jacobian' or 'hessian'
        device : str
        
    Returns:
    --------
        jacobian or hessian
    '''
    def get_mapping(model):

        mapping = {}
        name = 'model1'

        keys = list(model.state_dict().keys())
        for key in keys:

            y = re.findall(".[0-9]+", key)
            if len(y) > 0:
                y = y[0][1:]
                x = re.split(".[0-9]+", key)
                mapping[key] = name + '.' + x[0] + '[' + y + ']' + x[1]


            y = re.findall("_[0-9]+", key)
            if len(y) > 0:
                y = y[0][1:]
                x = re.split(".[0-9]+", key)
                mapping[key] = name + '.' + x[0] + '[' + y + ']'

        return mapping

    
    #model1 = copy.deepcopy(model)
    model1 = model.copy()
    mapping = get_mapping(model)
   
    # collect keys and shapes
    keys = list(model.state_dict().keys())
    shapes = []

    for params in model.parameters():
        shapes.append(params.shape)


    # turn a flattened vector to model params
    def param2statedict(p, keys, shapes):

        new_state_dict = {}

        start = 0
        n_group = len(keys)
        for i in range(n_group):
            shape = shapes[i]
            n_params = torch.prod(torch.tensor(shape))
            new_state_dict[keys[i]] = p[start:start+n_params].reshape(shape)
            start += n_params

        return new_state_dict
    
    def differentiable_load_state_dict(mapping, state_dict, model1):

        for key in keys:
            if mapping[key][-1] != ']':
                exec(f"del {mapping[key]}")
            exec(f"{mapping[key]} = state_dict[key]")
            

    # input: p, output: output
    def get_param2loss_fun(inputs, labels):

        def param2loss_fun(p):

            p = p[0]
            state_dict = param2statedict(p, keys, shapes)
            # this step is non-differentiable
            #model.load_state_dict(state_dict)
            differentiable_load_state_dict(mapping, state_dict, model1)
            if loss_mode == 'pred':
                pred_loss = torch.mean((model1(inputs) - labels)**2, dim=(0,1), keepdim=True)
                loss = pred_loss
            elif loss_mode == 'reg':
                reg_loss = model1.get_reg(reg_metric=reg_metric, lamb_l1=lamb_l1, lamb_entropy=lamb_entropy) * torch.ones(1,1)
                loss = reg_loss
            elif loss_mode == 'all':
                pred_loss = torch.mean((model1(inputs) - labels)**2, dim=(0,1), keepdim=True)
                reg_loss = model1.get_reg(reg_metric=reg_metric, lamb_l1=lamb_l1, lamb_entropy=lamb_entropy) * torch.ones(1,1)
                loss = pred_loss + lamb * reg_loss
            return loss

        return param2loss_fun
    
    fun = get_param2loss_fun(inputs, labels)    
    p = model2param(model)[None,:]
    if derivative == 'hessian':
        result = batch_hessian(fun, p)
    elif derivative == 'jacobian':
        result = batch_jacobian(fun, p)
    return result

get_derivative函数旨在计算模型参数相对于损失函数的雅可比(Jacobian)或海森(Hessian)矩阵。以下是该函数的详细说明:

函数参数

  • model : 模型对象,通常是一个PyTorch模型。
  • inputs : 输入数据,一个2D的torch.float张量。
  • labels : 标签数据,一个2D的torch.float张量。
  • derivative : 字符串,指定计算雅可比('jacobian')还是海森('hessian')矩阵。
  • loss_mode : 字符串,指定损失函数的模式('pred', 'reg', 'all')。
  • reg_metric : 字符串,指定正则化指标('w'表示权重)。
  • lamb : 浮点数,用于平衡损失和正则化项。
  • lamb_l1 : 浮点数,用于L1正则化项的权重。
  • lamb_entropy : 浮点数,用于熵正则化项的权重。

函数返回

  • 返回计算得到的雅可比或海森矩阵。

函数实现

函数首先定义了几个辅助函数:

  • get_mapping:用于获取模型参数的命名映射。
  • param2statedict:将参数向量转换为模型的状态字典。
  • differentiable_load_state_dict:将状态字典加载到模型中,以便计算梯度。
  • get_param2loss_fun:创建一个函数,该函数接受参数向量并返回损失。

然后,函数定义了fun,这是一个接受参数向量并返回损失值的函数。根据loss_mode参数,损失可以是预测损失、正则化损失或两者的组合。

最后,根据derivative参数,使用batch_hessianbatch_jacobian函数计算雅可比或海森矩阵。

三、spline.py

import torch

 这一部分的内容是纯torch实现的。

3.1 B_batch

def B_batch(x, grid, k=0, extend=True, device='cpu'):
    '''
    evaludate x on B-spline bases
    
    Args:
    -----
        x : 2D torch.tensor
            inputs, shape (number of splines, number of samples)
        grid : 2D torch.tensor
            grids, shape (number of splines, number of grid points)
        k : int
            the piecewise polynomial order of splines.
        extend : bool
            If True, k points are extended on both ends. If False, no extension (zero boundary condition). Default: True
        device : str
            devicde
    
    Returns:
    --------
        spline values : 3D torch.tensor
            shape (batch, in_dim, G+k). G: the number of grid intervals, k: spline order.
      
    Example
    -------
    >>> from kan.spline import B_batch
    >>> x = torch.rand(100,2)
    >>> grid = torch.linspace(-1,1,steps=11)[None, :].expand(2, 11)
    >>> B_batch(x, grid, k=3).shape
    '''
    
    x = x.unsqueeze(dim=2)
    grid = grid.unsqueeze(dim=0)
    
    if k == 0:
        value = (x >= grid[:, :, :-1]) * (x < grid[:, :, 1:])
    else:
        B_km1 = B_batch(x[:,:,0], grid=grid[0], k=k - 1)
        
        value = (x - grid[:, :, :-(k + 1)]) / (grid[:, :, k:-1] - grid[:, :, :-(k + 1)]) * B_km1[:, :, :-1] + (
                    grid[:, :, k + 1:] - x) / (grid[:, :, k + 1:] - grid[:, :, 1:(-k)]) * B_km1[:, :, 1:]
    
    # in case grid is degenerate
    value = torch.nan_to_num(value)
    return value

这段代码定义了一个名为 B_batch 的函数,用于在 B-spline 基础上评估输入 x。B-spline 是一种常用的样条函数,用于数据拟合和插值。

参数说明:

  • x:一个形状为 (number of splines, number of samples) 的 2D PyTorch 张量,表示输入数据。
  • grid:一个形状为 (number of splines, number of grid points) 的 2D PyTorch 张量,表示网格点。
  • k:整数,表示样条函数的分段多项式阶数,默认值为 0。
  • extend:布尔值,表示是否在两端扩展 k 个点,默认值为 True
  • device:字符串,表示设备类型,默认值为 'cpu'

函数返回一个形状为 (batch, in_dim, G+k) 的 3D PyTorch 张量,其中 G 是网格间隔数,k 是样条函数的阶数。

函数实现步骤:

  1. 输入参数处理

    • 将输入 x 和网格 grid 进行维度扩展,以便进行计算。
  2. 阶数为 0 的情况

    • 当 k 为 0 时,直接计算 x 与网格点之间的布尔值矩阵,表示 x 在网格点之间的取值情况。
  3. 阶数大于 0 的情况

    • 首先计算阶数为 k-1 的 B-spline 值 B_km1
    • 然后根据样条函数的定义,计算 x 在当前网格点区间内的 B-spline 值。这涉及到计算 x 与网格点的相对位置,以及使用 B_km1 的值进行线性组合。
  4. 处理异常情况

    • 使用 torch.nan_to_num 函数将任何计算结果中的 NaN 值替换为 0,以避免后续计算中的错误。

测试:

from kan.spline import *
import torch

# 创建一个随机输入张量 x 和网格张量 grid
x = torch.rand(100, 2)
grid = torch.linspace(-1, 1, steps=10)[None, :].expand(2, 10)

# 调用 B_batch 函数
B_values = B_batch(x, grid, k=3)
print(B_values.shape)
print(B_values)

torch.Size([100, 2, 6])
tensor([[[0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
          1.9858e-04],
         [0.0000e+00, 0.0000e+00, 0.0000e+00, 4.6404e-02, 5.6714e-01,
          3.7949e-01]],

        [[0.0000e+00, 0.0000e+00, 0.0000e+00, 2.2020e-02, 4.8497e-01,
          4.7332e-01],
         [0.0000e+00, 5.7094e-03, 3.6466e-01, 5.7832e-01, 5.1309e-02,
          0.0000e+00]],

        [[0.0000e+00, 1.7167e-04, 2.2175e-01, 6.5698e-01, 1.2110e-01,
          0.0000e+00],
         [0.0000e+00, 0.0000e+00, 0.0000e+00, 9.4664e-03, 4.0435e-01,
          5.4731e-01]],

        ...,

        [[0.0000e+00, 0.0000e+00, 0.0000e+00, 1.2196e-01, 6.5738e-01,
          2.2050e-01],
         [0.0000e+00, 0.0000e+00, 0.0000e+00, 6.1864e-03, 3.7053e-01,
          5.7395e-01]],

        [[0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
          9.4280e-02],
         [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
          4.5621e-04]],

        [[0.0000e+00, 1.4357e-03, 2.8587e-01, 6.2895e-01, 8.3745e-02,
          0.0000e+00],
         [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 3.3225e-02,
          5.2970e-01]]])

3.2 coef2curve

def coef2curve(x_eval, grid, coef, k, device="cpu"):
    '''
    converting B-spline coefficients to B-spline curves. Evaluate x on B-spline curves (summing up B_batch results over B-spline basis).
    
    Args:
    -----
        x_eval : 2D torch.tensor
            shape (batch, in_dim)
        grid : 2D torch.tensor
            shape (in_dim, G+2k). G: the number of grid intervals; k: spline order.
        coef : 3D torch.tensor
            shape (in_dim, out_dim, G+k)
        k : int
            the piecewise polynomial order of splines.
        device : str
            devicde
        
    Returns:
    --------
        y_eval : 3D torch.tensor
            shape (number of samples, in_dim, out_dim)
        
    '''
    
    b_splines = B_batch(x_eval, grid, k=k)
    y_eval = torch.einsum('ijk,jlk->ijl', b_splines, coef.to(b_splines.device))
    
    return y_eval

这个函数 coef2curve 的目的是将 B-spline 系数转换为 B-spline 曲线。它通过在 B-spline 基础上对 B-spline 结果求和来实现这一点。以下是函数的详细说明:

参数说明:

  • x_eval:一个形状为 (batch, in_dim) 的 2D PyTorch 张量,表示要评估的点的集合。
  • grid:一个形状为 (in_dim, G+2k) 的 2D PyTorch 张量,其中 G 是网格间隔数,k 是样条函数的阶数。这个网格用于定义 B-spline 的基础。
  • coef:一个形状为 (in_dim, out_dim, G+k) 的 3D PyTorch 张量,表示 B-spline 系数。out_dim 是输出空间的维度。
  • k:整数,表示 B-spline 的阶数。
  • device:字符串,表示设备类型,默认为 "cpu"

函数实现步骤:

  1. 计算 B-spline 基础

    • 使用 B_batch 函数计算输入 x_eval 在 B-spline 基础上的值,这会得到一个形状为 (batch, in_dim, G+k) 的张量 b_splines
  2. 计算 B-spline 曲线

    • 使用 torch.einsum 函数对 b_splines 和 coef 进行张量外积操作,得到 B-spline 曲线的评估结果。torch.einsum 的 'ijk,jlk->ijl' 指定了如何对这两个张量进行操作,最终得到一个形状为 (batch, in_dim, out_dim) 的张量 y_eval
  3. 返回结果

    • 函数返回形状为 (number of samples, in_dim, out_dim) 的张量 y_eval,它表示在输入 x_eval 点上的 B-spline 曲线评估结果。

 测试:

# 测试数据
x_eval = torch.rand(10, 2)  # 假设有 10 个样本,每个样本 2 维
grid = torch.linspace(-1, 1, steps=20)[None, :].reshape(2, 10)  # 创建一个线性网格
coef = torch.rand(2, 2, 6)  # 假设有 2 个输出维度和 6 个 B 样条系数
k = 3  # B 样条阶数
device = "cpu"  # 设备

# 调用 coef2curve 函数
y_eval = coef2curve(x_eval, grid, coef, k, device)

# 打印输出
print("y_eval.shape:", y_eval.shape)
print("y_eval[:5, :, :2]:", y_eval[:5, :, :2])  # 打印前 5 个样本的前 2 个输出维度的值

y_eval.shape: torch.Size([10, 2, 2])
y_eval[:5, :, :2]: tensor([[[0.2986, 0.3585],
         [0.5725, 0.2797]],

        [[0.2516, 0.2835],
         [0.4817, 0.4179]],

        [[0.0160, 0.0336],
         [0.4771, 0.4242]],

        [[0.3026, 0.3798],
         [0.3606, 0.2922]],

        [[0.3028, 0.3852],
         [0.6890, 0.1768]]])

3.3 curve2coef

def curve2coef(x_eval, y_eval, grid, k, lamb=1e-8):
    '''
    converting B-spline curves to B-spline coefficients using least squares.
    
    Args:
    -----
        x_eval : 2D torch.tensor
            shape (in_dim, out_dim, number of samples)
        y_eval : 2D torch.tensor
            shape (in_dim, out_dim, number of samples)
        grid : 2D torch.tensor
            shape (in_dim, grid+2*k)
        k : int
            spline order
        lamb : float
            regularized least square lambda
            
    Returns:
    --------
        coef : 3D torch.tensor
            shape (in_dim, out_dim, G+k)
    '''
    batch = x_eval.shape[0]
    in_dim = x_eval.shape[1]
    out_dim = y_eval.shape[2]
    n_coef = grid.shape[1] - k - 1
    mat = B_batch(x_eval, grid, k)
    mat = mat.permute(1,0,2)[:,None,:,:].expand(in_dim, out_dim, batch, n_coef)
    y_eval = y_eval.permute(1,2,0).unsqueeze(dim=3)
    device = mat.device
    
    #coef = torch.linalg.lstsq(mat, y_eval,
                             #driver='gelsy' if device == 'cpu' else 'gels').solution[:,:,:,0]
        
    XtX = torch.einsum('ijmn,ijnp->ijmp', mat.permute(0,1,3,2), mat)
    Xty = torch.einsum('ijmn,ijnp->ijmp', mat.permute(0,1,3,2), y_eval)
    n1, n2, n = XtX.shape[0], XtX.shape[1], XtX.shape[2]
    identity = torch.eye(n,n)[None, None, :, :].expand(n1, n2, n, n).to(device)
    A = XtX + lamb * identity
    B = Xty
    coef = (A.pinverse() @ B)[:,:,:,0]
    
    return coef

这段代码是一个用于将B样条曲线转换为B样条系数的函数,使用了最小二乘法。

先吐个槽:作者把注释写错了,就是下面这两句

        x_eval : 2D torch.tensor
            shape (in_dim, out_dim, number of samples)
        y_eval : 2D torch.tensor
            shape (in_dim, out_dim, number of samples)

这两个参数形状应该与上面两个函数中涉及的形状一致,以及那个变量batch就是number of samples。看了好一会儿,怎么看怎么不对劲。

输入参数:

  • x_eval:一个二维的PyTorch张量,形状为(number_of_samples, in_dim),表示在评估点上的x坐标。
  • y_eval:一个二维的PyTorch张量,形状为(number_of_samples, in_dim, out_dim),表示在评估点上的y坐标。
  • grid:一个二维的PyTorch张量,形状为(in_dim, grid + 2 * k),表示控制点网格。
  • k:一个整数,表示B样条的阶数。
  • lamb:一个浮点数,用于正则化最小二乘法的lambda参数,默认值为1e-8

返回值

  • coef:一个三维的PyTorch张量,形状为(in_dim, out_dim, G + k),表示B样条曲线的系数。

代码解释

  1. 维度信息获取

    • batch:输入数据的批次大小。
    • in_dim:输入数据的维度。
    • out_dim:输出数据的维度。
    • n_coef:B样条系数的数量,计算为grid.shape[1] - k - 1
  2. 矩阵计算

    • mat:通过调用B_batch函数计算得到的矩阵,该矩阵包含了B样条曲线的基函数值。
  3. 张量操作

    • mat.permute(1,0,2):将mat的维度重新排列。
    • mat[:,None,:,:].expand(in_dim, out_dim, batch, n_coef):扩展mat以匹配y_eval的形状。
    • y_eval.permute(1,2,0).unsqueeze(dim=3):重新排列y_eval的维度并增加一个批次维度。
  4. 设备检测

    • device:检测mat所在的设备(CPU或GPU)。
  5. 最小二乘法

    • XtX:计算矩阵mat的转置与自身的乘积。
    • Xty:计算矩阵mat的转置与y_eval的乘积。
    • identity:创建一个单位矩阵,并将其扩展以匹配XtX的形状。
    • A:将XtXlamb乘以的单位矩阵相加,用于正则化。
    • BXty
  6. 求解线性方程组

    • (A.pinverse() @ B)[:,:,:,0]:使用A的伪逆乘以B来求解线性方程组,并取结果的前n_coef个元素。
  7. 返回结果

    • 函数返回计算得到的B样条系数coef

 测试:

# 生成随机x_eval和y_eval数据
x_eval = torch.randn(10, 2)  # 1个批次,2个输入维度,10个样本点
y_eval = torch.randn(10, 2, 2)  # 1个批次,2个输出维度,10个样本点
grid = torch.randn(2, 10)       # 1个输入维度,10个节点

k = 3
lamb=1e-8

coef = curve2coef(x_eval, y_eval, grid, k, lamb)
print("coef.shape:", coef.shape)
print("coef[:5, :, :2]:", coef[:5, :, :2]) 

 coef.shape: torch.Size([2, 2, 6])
coef: tensor([[[ 2.5919e-03, -3.9341e-02,  3.7596e+00,  6.2739e+00,  1.5958e+00,
           0.0000e+00],
         [ 2.0447e-03,  3.1825e-01,  8.1563e-01,  6.2472e-01, -3.4195e-01,
           0.0000e+00]],

        [[ 7.4711e-02, -3.1942e+00,  4.7249e+00,  6.0009e+00, -2.6232e+00,
          -3.6251e+00],
         [-2.0716e-01,  6.7293e+00, -9.1516e+00, -1.1859e+01,  1.0950e+00,
           5.0081e-01]]])

3.4 extend_grid

def extend_grid(grid, k_extend=0):
    '''
    extend grid
    '''
    h = (grid[:, [-1]] - grid[:, [0]]) / (grid.shape[1] - 1)

    for i in range(k_extend):
        grid = torch.cat([grid[:, [0]] - h, grid], dim=1)
        grid = torch.cat([grid, grid[:, [-1]] + h], dim=1)

    return grid

extend_grid函数用于扩展给定的网格

  1. 计算网格间距:首先,函数计算网格中相邻两个点之间的平均距离h。这通过从最后一个网格点到第一个网格点的差除以网格点的数量(减去1)来实现。这样计算得到的h代表了网格的平均间隔。

  2. 扩展网格:函数通过循环k_extend次来扩展网格。在每次迭代中,它执行以下操作:

    • 左扩展:将新的网格点添加到网格的左侧,这个点的值是当前网格的左侧点减去h。这样,网格在左侧增加了h的长度。
    • 右扩展:将新的网格点添加到网格的右侧,这个点的值是当前网格的右侧点加上h。这样,网格在右侧同样增加了h的长度。
  3. 返回扩展后的网格:经过k_extend次扩展后,函数返回扩展后的网格。

 测试:

# 创建一个简单的网格
grid = torch.tensor([[0, 1], [2, 3], [4, 5]], dtype=torch.float32)

# 假设我们要向网格的左右两侧各扩展1个单位长度
k_extend = 2
    
# 使用extend_grid函数扩展网格
extended_grid = extend_grid(grid, k_extend)

print(grid)
print(extended_grid)
    

 tensor([[0., 1.],
             [2., 3.],
             [4., 5.]])
tensor([[-2., -1.,  0.,  1.,  2.,  3.],
        [ 0.,  1.,  2.,  3.,  4.,  5.],
        [ 2.,  3.,  4.,  5.,  6.,  7.]])

四、总结 

utils.py 包含了一系列用于处理符号函数、生成数据集、拟合参数、创建稀疏掩码、添加符号函数到库、四舍五入表达式、增强输入、计算模型参数的Jacobian和Hessian矩阵,以及从数据中创建训练和测试数据集的功能。

  1. 符号函数处理f_invf_inv2f_inv3f_inv4f_inv5f_sqrtf_power1d5f_invsqrtf_logf_tanf_arctanhf_arcsinf_arccosf_exp:这些函数是用于处理不同数学操作的lambda函数,它们通常用于在符号计算中生成表达式。
  2. 创建数据集create_dataset这个函数用于生成一个数据集,其中包含了输入和对应的标签。它接受一个函数f,该函数定义了数据集的生成方式,并允许用户指定变量数量、范围、训练和测试样本数量等。
  3. 拟合参数fit_params这个函数用于拟合参数abcd,使得表达式|y-(cf(ax+b)+d)|^2最小化。它通过扫描参数ab的值来找到最佳拟合。
  4. 创建稀疏掩码sparse_mask这个函数用于创建一个稀疏掩码,它定义了输入和输出维度之间的连接。
  5. 添加符号函数到库add_symbolic:这个函数允许用户将自定义的符号函数添加到SYMBOLIC_LIB库中。
  6. 四舍五入表达式ex_round这个函数用于将符号表达式中浮点数的精度四舍五入到指定的小数位数。
  7. 增强输入augment_input这个函数用于增加输入数据的维度,通过添加辅助变量来增强原始变量。
  8. 雅可比矩阵计算batch_jacobian:这个函数计算给定函数 func 在输入 x 上的全梯度(Jacobian)。它首先将 func 的输出求和,然后使用 torch.autograd.functional.jacobian 函数计算该和函数对 x 的全梯度。
  9. 海森矩阵计算batch_hessian:这个函数计算给定函数 func 在输入 x 上的Hessian矩阵。它首先计算 func 的Jacobian,然后计算该Jacobian对 x 的全梯度,即Hessian矩阵。
  10. 划分数据集create_dataset_from_data这个函数从输入数据 inputs 和标签 labels 中创建训练集和测试集。它首先随机选择一部分数据作为训练集,其余作为测试集,并将这些数据转换为所需的格式。
  11. 导数计算get_derivative:这个函数计算损失函数关于模型参数的导数(Jacobian 或 Hessian)。它首先将模型参数从模型结构转换为一个扁平化的向量,然后定义一个函数 param2loss_fun,该函数将参数向量转换回模型结构,并计算损失。最后,它使用 batch_jacobian 或 batch_hessian 函数计算损失对参数的导数。
  12. 转化模型结构model2param这个函数将模型参数从模型结构转换为一个扁平化的向量,便于后续操作如梯度计算等。

spline.py 提供了处理B样条曲线的函数,包括计算B样条基、从B样条系数到曲线的转换,以及从曲线到B样条系数的转换。

  • B_batch:这个函数计算输入点在B样条基上的值。
    • 接收输入点 x、网格 grid、B样条阶数 k、是否扩展网格 extend 和设备 device
    • 通过递归调用自身来计算高阶B样条基的值。
    • 返回输入点在B样条基上的值。
  • coef2curve:这个函数将B样条系数转换为B样条曲线。
    • 接收评估点 x_eval、网格 grid、系数 coef、阶数 k 和设备 device
    • 使用 B_batch 函数计算B样条基。
    • 通过求和B样条基与系数的乘积来得到B样条曲线的值。
    • 返回B样条曲线的值。
  • curve2coef:这个函数将B样条曲线转换为B样条系数。
    • 接收评估点 x_eval、曲线值 y_eval、网格 grid、阶数 k 和正则化系数 lamb
    • 使用 B_batch 函数计算B样条基。
    • 通过最小二乘法求解系数,其中考虑了正则化项以避免过拟合。
    • 返回B样条系数。
  • extend_grid:这个函数扩展B样条网格。
    • 接收网格 grid 和扩展阶数 k_extend
    • 在网格的两端添加额外的点,以扩展网格。
    • 返回扩展后的网格。

 

  • 12
    点赞
  • 24
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值