Efficient-KAN源码链接
改进细节
1.内存效率提升
KAN网络的原始实现的性能问题主要在于它需要扩展所有中间变量以执行不同的激活函数。对于具有in_features个输入和out_features个输出的层,原始实现需要将输入扩展为shape为(batch_size, out_features, in_features)的tensor以执行激活函数。然而,所有激活函数都是一组固定基函数(3阶B样条)的线性组合。鉴于此,拟将计算重新表述为不同的基函数激活输入,然后将它们线性组合。这种重新表述可以显著减少内存消耗,并使计算变得更加简单的矩阵乘法,自然地适用于前向和后向传递。
2.正则化方法的改变
稀疏化被认为对KAN的可解释性至关重要。作者提出了一种定义在输入样本上的L1正则化,它需要对**(batch_size, out_features, in_features)** tensor进行非线性操作,因此与重新表述不兼容。拟改为对权重进行L1正则化,这在NN中更为常见,并且与重新表述兼容。
3.激活函数缩放选项
除了可学习的激活函数(B样条),原始实现还包括对每个激活函数的可学习缩放 ( w s ) (w_s) (ws)。拟提供一个名为enable_standalone_scale_spline的选项,默认情况下为True,以包含此功能。禁用它会使模型更高效,但可能会影响结果。这需要更多实验验证。
4.参数初始化的改变
为了解决在MNIST数据集上的性能问题,该代码修改了参数的初始化方式,采用Kaiming初始化。
KAN_fast.py解析
基本参数和类定义
import torch
import torch.nn.functional as F
import math
class KANLinear(torch.nn.Module):
def __init__(
self,
in_features,
out_features,
grid_size=5, # 网格大小,默认为 5
spline_order=3, # 分段多项式的阶数,默认为 3
scale_noise=0.1, # 缩放噪声,默认为 0.1
scale_base=1.0, # 基础缩放,默认为 1.0
scale_spline=1.0, # 分段多项式的缩放,默认为 1.0
enable_standalone_scale_spline=True,
base_activation=torch.nn.SiLU, # 基础激活函数,默认为 SiLU(Sigmoid Linear Unit)
grid_eps=0.02,
grid_range=[-1, 1], # 网格范围,默认为 [-1, 1]
):
super(KANLinear, self).__init__()
self.in_features = in_features
self.out_features = out_features
self.grid_size = grid_size # 设置网格大小和分段多项式的阶数
self.spline_order = spline_order
h = (grid_range[1] - grid_range[0]) / grid_size # 计算网格步长
生成网格
grid = ( # 生成网格
(
torch.arange(-spline_order, grid_size + spline_order + 1) * h
+ grid_range[0]
)
.expand(in_features, -1)
.contiguous()
)
self.register_buffer("grid", grid) # 将网格作为缓冲区注册
1. torch.arange(-spline_order, grid_size + spline_order + 1)
**torch.arange(start, end)**
:生成一个从start
到end-1
的整数序列(左闭右开区间)。**-spline_order**
:从负的spline_order
开始。**grid_size + spline_order + 1**
:终止于grid_size + spline_order
(不包括+1
)。
这个序列的长度是 grid_size + 2 * spline_order + 1
,用于涵盖所有需要的网格点,包括两端的扩展区域。
2. * h
- 这一步将生成的整数序列乘以步长
h
,将索引序列转换为实际的网格位置。
3. + grid_range[0]
- 这一步将整个网格位置进行平移,使得网格的起始点与
grid_range[0]
对齐。
如果 grid_range[0] = -1
,则每个位置都会减去 1
4. .expand(in_features, -1)
**.expand()**
:将这个网格复制in_features
次,以适应输入特征的维度。具体来说,它将原本的一维网格向量扩展成一个in_features
×(grid_size + 2 * spline_order + 1)
的二维张量。 其中每一行都是相同的网格向量。
5. .contiguous()
**.contiguous()**
:确保扩展后的张量在内存中是连续存储的,方便后续的计算和操作。虽然在大多数情况下这个操作是可选的,但它可以提高计算效率并避免潜在的问题。
最终效果:
这段代码生成了一个二维张量 grid
,它的形状为 [in_features, grid_size + 2 * spline_order + 1]
,其中每一行都是相同的、覆盖整个 grid_range
并适当扩展的网格点序列。这个网格用于模型中的 B 样条或其他基函数计算,使得模型可以在输入数据范围内执行灵活的插值和拟合操作。
初始化可训练参数
self.base_weight = torch.nn.Parameter(torch.Tensor(out_features, in_features)) # 初始化基础权重和分段多项式权重
self.spline_weight = torch.nn.Parameter(
torch.Tensor(out_features, in_features, grid_size + spline_order)
)
if enable_standalone_scale_spline: # 如果启用独立的分段多项式缩放,则初始化分段多项式缩放参数
self.spline_scaler = torch.nn.Parameter(
torch.Tensor(out_features, in_features)
)
1. self.base_weight = torch.nn.Parameter(torch.Tensor(out_features, in_features))
( w b ) (w_b) (wb)
**torch.Tensor(out_features, in_features)**
:创建一个形状为(out_features, in_features)
的未初始化张量,用于存储基础线性层的权重。这个张量的元素初始时没有具体的数值,通常在后续的reset_parameters()
方法中进行初始化。**torch.nn.Parameter**
:将这个张量封装成torch.nn.Parameter
对象。这意味着这个张量会被视为模型的可训练参数,PyTorch 会自动将其包含在模型的参数列表中,并在反向传播时更新其值。**self.base_weight**
:这个属性存储的是基础线性变换的权重矩阵。这个矩阵将在前向传播过程中被用来对输入特征进行线性变换。
2. self.spline_weight = torch.nn.Parameter(torch.Tensor(out_features, in_features, grid_size + spline_order))
( c i ) (c_i) (ci)
**torch.Tensor(out_features, in_features, grid_size + spline_order)**
:创建一个形状为(out_features, in_features, grid_size + spline_order)
的未初始化张量,用于存储分段多项式的权重。这些权重将用于 B 样条或其他类似方法的计算。**torch.nn.Parameter**
:同样地,将这个张量封装成torch.nn.Parameter
,使其成为模型的可训练参数。**self.spline_weight**
:这个属性存储的是与分段多项式相关的权重。这些权重决定了如何将输入特征映射到输出特征,特别是在使用 B 样条等非线性激活函数时。
为什么 spline_weight
的形状是 (out_features, in_features, grid_size + spline_order)
?
**out_features**
和**in_features**
:与base_weight
类似,表示输出和输入的特征数量。**grid_size + spline_order**
:这个维度表示在 B 样条或其他分段多项式方法中,每个输入特征需要使用的基函数的数量。通过这些基函数的线性组合,可以生成灵活的非线性激活。
3. if enable_standalone_scale_spline:
- 这个条件语句检查
enable_standalone_scale_spline
是否为True
。如果为True
,则会为每个分段多项式激活函数引入一个独立的缩放参数。
4. self.spline_scaler = torch.nn.Parameter(torch.Tensor(out_features, in_features))
( w s ) (w_s) (ws)
**torch.Tensor(out_features, in_features)**
:创建一个形状为(out_features, in_features)
的张量,用于存储独立的分段多项式缩放参数。**torch.nn.Parameter**
:将张量封装成torch.nn.Parameter
,使其成为可训练参数。**self.spline_scaler**
:这个属性存储的是分段多项式的缩放参数。每个spline_weight
都有一个对应的缩放参数,可以单独调整其幅度,从而提供更大的灵活性。
其他实例属性
self.scale_noise = scale_noise # 保存缩放噪声、基础缩放、分段多项式的缩放、是否启用独立的分段多项式缩放、基础激活函数和网格范围的容差
self.scale_base = scale_base
self.scale_spline = scale_spline
self.enable_standalone_scale_spline = enable_standalone_scale_spline
self.base_activation = base_activation()
self.grid_eps = grid_eps
self.reset_parameters() # 重置参数
Kaiming初始化权重(reset_parameters)
def reset_parameters(self):
torch.nn.init.kaiming_uniform_(self.base_weight, a=math.sqrt(5) * self.scale_base)# 使用 Kaiming 均匀初始化基础权重
with torch.no_grad():
noise = (# 生成缩放噪声
(
torch.rand(self.grid_size + 1, self