A Small Survey of KAN (Kolmogorov-Arnold Networks)


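Preface

KAN is a new deep-learning architecture proposed by MIT researchers at the end of April 2024 as an alternative to the conventional MLP (multilayer perceptron). As of 2024-09-13 it already had 150+ citations. I am a newcomer to this area and started this post to keep notes on the papers I survey (corrections and additions are welcome!)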

KAN 1.0 paper link
KAN 2.0 paper link
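For orientation, here are the two formulas at the heart of KAN, as I understand them from the KAN 1.0 paper (notation may differ slightly from the original). The Kolmogorov-Arnold representation theorem states that a continuous function on $[0,1]^n$ can be written using only univariate functions and addition:

$$
f(x_1,\dots,x_n) = \sum_{q=1}^{2n+1} \Phi_q\!\left(\sum_{p=1}^{n} \phi_{q,p}(x_p)\right)
$$

A KAN layer generalizes the inner sum. With $n_l$ inputs, each edge $i \to j$ carries its own learnable univariate function $\phi_{l,j,i}$ and each output node just sums its incoming edges, whereas an MLP layer applies a fixed activation $\sigma$ after a linear map:

$$
x_{l+1,j} = \sum_{i=1}^{n_l} \phi_{l,j,i}(x_{l,i}) \qquad \text{vs.} \qquad x_{l+1} = \sigma(W_l x_l + b_l)
$$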

Related Work

Basis-function improvements

(KAN uses B-splines to fit the learnable univariate function on each edge)
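To make the B-spline parameterization concrete, here is a minimal NumPy sketch. It roughly follows the KAN 1.0 paper's formulation, in which each edge function is φ(x) = w_b·silu(x) + w_s·Σ_i c_i B_i(x). The B_i are order-k B-splines on a uniform grid of G intervals, extended by k knots on each side, which gives G + k basis functions. The helper names (`make_grid`, `bspline_basis`, `kan_edge`) are hypothetical, and this is not the `pykan` implementation.

```python
import numpy as np


def make_grid(low, high, G, k):
    """Uniform knot vector: G intervals on [low, high], extended by k knots on each side."""
    h = (high - low) / G
    return low + h * np.arange(-k, G + k + 1)


def bspline_basis(x, knots, k):
    """Evaluate all order-k B-spline bases at x via the Cox-de Boor recursion.

    Returns an array of shape (len(x), len(knots) - k - 1).
    """
    x = np.asarray(x, dtype=float)[:, None]
    t = np.asarray(knots, dtype=float)[None, :]
    # Order 0: indicator of the half-open interval [t_i, t_{i+1})
    B = ((x >= t[:, :-1]) & (x < t[:, 1:])).astype(float)
    for p in range(1, k + 1):
        left = (x - t[:, :-(p + 1)]) / (t[:, p:-1] - t[:, :-(p + 1)]) * B[:, :-1]
        right = (t[:, p + 1:] - x) / (t[:, p + 1:] - t[:, 1:-p]) * B[:, 1:]
        B = left + right
    return B


def kan_edge(x, coef, knots, k=3, w_b=1.0, w_s=1.0):
    """One KAN edge function: w_b * silu(x) + w_s * sum_i c_i B_i(x)."""
    x = np.asarray(x, dtype=float)
    silu = x / (1.0 + np.exp(-x))
    return w_b * silu + w_s * bspline_basis(x, knots, k) @ coef


G, k = 5, 3
knots = make_grid(-1.0, 1.0, G, k)           # G + 2k + 1 = 12 knots
xs = np.linspace(-1.0, 0.99, 50)
B = bspline_basis(xs, knots, k)              # shape (50, G + k) = (50, 8)
coef = np.random.default_rng(0).normal(size=G + k)  # the learnable c_i
ys = kan_edge(xs, coef, knots, k)
```

Inside [-1, 1) the basis functions are non-negative and sum to 1. Training a KAN amounts to learning `coef` (and w_b, w_s) separately for every edge.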

《Kolmogorov-arnold networks are radial basis function networks》
  • Replaces the B-spline basis computations with Gaussian RBFs (radial basis functions)

  • Open-source code: Code
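A minimal PyTorch sketch of the idea follows. Each input coordinate is expanded into Gaussian bumps exp(-((x - c_m)/h)²) centered on a fixed grid c_m. A linear map over all (input, basis) pairs then produces the outputs, so every edge still gets its own learnable univariate function. The class name, the center range [-2, 2], and the choice of bandwidth h = grid spacing are my assumptions. The paper's released code may also add normalization or a base branch.

```python
import torch
import torch.nn as nn


class GaussianRBFKANLayer(nn.Module):
    """Hypothetical minimal KAN layer using Gaussian RBF bases instead of B-splines."""

    def __init__(self, in_dim, out_dim, num_centers=8, low=-2.0, high=2.0):
        super().__init__()
        # Fixed (non-learnable) RBF centers shared by all edges
        self.register_buffer("centers", torch.linspace(low, high, num_centers))
        self.h = (high - low) / (num_centers - 1)  # bandwidth = grid spacing
        # One coefficient per (input, basis, output): each edge i -> j has its own function
        self.coef = nn.Parameter(torch.randn(in_dim * num_centers, out_dim) * 0.1)

    def forward(self, x):  # x: (batch, in_dim)
        # (batch, in_dim, 1) - (num_centers,) -> (batch, in_dim, num_centers)
        phi = torch.exp(-((x[..., None] - self.centers) / self.h) ** 2)
        # Sum over inputs and bases: out_j = sum_i sum_m c[i, m, j] * phi_m(x_i)
        return phi.flatten(start_dim=1) @ self.coef  # (batch, out_dim)


torch.manual_seed(0)
layer = GaussianRBFKANLayer(in_dim=4, out_dim=3)
x = torch.randn(16, 4)
y = layer(x)
y.sum().backward()  # gradients flow into the RBF coefficients
```

Gaussians avoid the recursive B-spline evaluation, which is the main speed argument behind this substitution. Layers can be stacked with `nn.Sequential` like ordinary linear layers.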

Combining KAN with other deep-learning settings

Federated learning: 《F-KANs: Federated Kolmogorov-Arnold Networks》

  • Each client trains a local KAN; a central server aggregates them into a global model.
  • Experiments (code: Code)
    • 2 clients
    • Dataset: Iris (dataset link)
    • Summary: the KAN model is not only accurate but also consistent in
      its predictions, capturing all true positives (300s vs 2s / KAN: 100% accuracy)
    • The spline-based univariate features enable KAN to capture
      complex patterns quickly and accurately, resulting in stable
      and high performance with fewer training rounds.
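The client/server workflow above can be sketched as a single server-side aggregation step. This assumes simple FedAvg-style weighted parameter averaging, which may differ from the exact aggregation used in F-KANs. `fedavg` is a hypothetical helper. It works for any `nn.Module`, so a `nn.Linear(4, 3)` (Iris has 4 features and 3 classes) stands in for each client's local KAN.

```python
import copy

import torch
import torch.nn as nn


def fedavg(client_models, weights=None):
    """Server step: weighted average of the clients' parameters (FedAvg-style)."""
    n = len(client_models)
    if weights is None:
        weights = [1.0 / n] * n  # equal weights, e.g. equally sized local datasets
    global_model = copy.deepcopy(client_models[0])
    states = [m.state_dict() for m in client_models]
    averaged = {key: sum(w * s[key] for w, s in zip(weights, states)) for key in states[0]}
    global_model.load_state_dict(averaged)
    return global_model


# Two clients; nn.Linear(4, 3) is only a stand-in for a local KAN model.
torch.manual_seed(0)
clients = [nn.Linear(4, 3), nn.Linear(4, 3)]
# ... each client would train its own model on its local data here ...
global_model = fedavg(clients)
```

In practice the weights are often set proportional to each client's number of training samples. The averaged model is then sent back to the clients for the next round.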

Time series 1: 《KAN for Time Series Analysis》

  • KANs outperform conventional MLPs in a real-world satellite traffic forecasting task, providing more accurate results with considerably fewer parameters.

  • Aims to evaluate the practicality of KANs in real-world scenarios, analyzing their efficiency in terms of the
    number of trainable parameters and discussing how the additional degrees of freedom might affect forecasting performance.

  • Uses real-world satellite traffic data.

  • Time-series forecasting

  • Two-layer KAN; the numbers of nodes in the input and output layers correspond to the total numbers of input and output time steps, respectively.

  • Forecasts traffic conditions from satellite data

  • Experiments

    • Compares the performance of the KAN and MLP architectures (6 beam areas)
    • Uses one week of data to predict one day
    • KAN adjusts rapidly, whereas the MLP exhibits a lag
    • KAN tracks the rapid volume changes, whereas the MLP moderately over- or under-predicts
    • KAN remains robust despite the complexity and higher volume
    • Parameters: this reduced complexity suggests that KANs can achieve higher or comparable forecasting accuracy with simpler and potentially faster models (suited to resource-constrained and rapid-deployment scenarios).
      (Figure: framework diagram)
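The setup above can be sketched in NumPy as a windowing step plus a rough parameter count. In that setup, one week of history goes in and one day comes out, and the input and output layer sizes equal the numbers of time steps. Several values here are illustrative assumptions rather than the paper's settings: hourly sampling (168 input and 24 output steps), the hidden widths, and the grid size G = 5 with spline order k = 3. The KAN count includes only the G + k spline coefficients per edge and ignores any per-edge scale or base weights. All helper names are hypothetical.

```python
import numpy as np


def make_windows(series, n_in, n_out):
    """Slice a 1-D series into (history, target) pairs with stride 1."""
    n = len(series) - n_in - n_out + 1
    X = np.stack([series[s:s + n_in] for s in range(n)])
    Y = np.stack([series[s + n_in:s + n_in + n_out] for s in range(n)])
    return X, Y


def kan_param_count(widths, G=5, k=3):
    """Approximate KAN size: each edge carries G + k B-spline coefficients."""
    return sum(a * b * (G + k) for a, b in zip(widths[:-1], widths[1:]))


def mlp_param_count(widths):
    """Fully connected layers: weights plus biases."""
    return sum(a * b + b for a, b in zip(widths[:-1], widths[1:]))


# Hypothetical hourly traffic for 30 days: 7 days (168 h) in, 1 day (24 h) out.
hours = np.arange(30 * 24)
series = np.sin(2 * np.pi * hours / 24) + 0.1 * np.cos(2 * np.pi * hours / 168)
X, Y = make_windows(series, n_in=168, n_out=24)
print(X.shape, Y.shape)  # (529, 168) (529, 24)
print(kan_param_count([168, 4, 24]), mlp_param_count([168, 64, 24]))  # 6144 12376
```

Whether a KAN ends up smaller than an MLP depends entirely on the widths chosen. The counts above only show how each grows with layer sizes and grid resolution.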

Time series 2: 《Kolmogorov-Arnold Networks (KANs) for Time Series Analysis》

  • T-KAN: detects concept drift within (univariate) time series
    • Core idea: use a sliding window (in the example, two historical time steps predict the next time step; different KAN structures and activation functions represent different concepts) to observe how the KAN varies and thereby identify concept drift.
  • MT-KAN: improves predictive performance on multivariate time series
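The sliding-window idea behind T-KAN can be illustrated with a simplified NumPy sketch built on two loudly flagged substitutions. First, instead of a real KAN, each window is fitted with an additive model y_t ≈ c_0 + f_1(x_{t-2}) + f_2(x_{t-1}) whose univariate functions are cubic polynomials, a crude stand-in for KAN's learnable edge functions. Second, rather than inspecting how the learned functions change, as T-KAN does, drift is flagged with a simpler residual check: fit on a reference window, then mark later windows whose error exceeds `ratio` times the reference error. All names, window sizes, thresholds, and the synthetic data are hypothetical.

```python
import numpy as np


def lag_pairs(x):
    """Two historical steps -> next step: ((x[t-2], x[t-1]), x[t])."""
    X = np.stack([x[:-2], x[1:-1]], axis=1)
    return X, x[2:]


def additive_basis(X, degree=3):
    """Per-lag polynomial basis: a crude stand-in for KAN's univariate edge functions."""
    cols = [np.ones(len(X))]
    for j in range(X.shape[1]):
        for d in range(1, degree + 1):
            cols.append(X[:, j] ** d)
    return np.stack(cols, axis=1)


def detect_drift(series, win=200, ratio=2.0, degree=3):
    """Fit on the first window; flag later windows whose MSE exceeds ratio * reference MSE."""
    X, y = lag_pairs(series[:win])
    A = additive_basis(X, degree)
    coef, *_ = np.linalg.lstsq(A, y, rcond=None)
    ref_mse = np.mean((A @ coef - y) ** 2)
    flagged = []
    for start in range(win, len(series) - win + 1, win):
        Xw, yw = lag_pairs(series[start:start + win])
        mse = np.mean((additive_basis(Xw, degree) @ coef - yw) ** 2)
        if mse > ratio * ref_mse:
            flagged.append(start)  # start index of a window that no longer fits
    return flagged


# Synthetic series whose generating rule changes at t = 600 (hypothetical data).
rng = np.random.default_rng(0)
x = np.zeros(1200)
for t in range(2, 1200):
    phi1, phi2 = (0.5, -0.3) if t < 600 else (-0.5, 0.3)
    x[t] = phi1 * x[t - 1] + phi2 * x[t - 2] + rng.normal()
print(detect_drift(x))
```

With this setup, windows before the change point fit about as well as the reference window, while windows after it show a much larger error and get flagged. A faithful T-KAN reproduction would instead train a KAN per window and compare the learned activation functions.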

Time series 3: 《KAN4TSF: Are KAN and KAN-based models Effective for Time Series Forecasting?》

[Continuously updated…]

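Strictly speaking, Kolmogorov-Arnold Networks (KAN) are a network architecture rather than an activation function: every edge carries a learnable univariate function (B-splines in the original paper). A common simplification, however, is to replace the fixed activation units of an MLP (multilayer perceptron) with a learnable "KAN-style" nonlinearity. The one used here combines a linear (polynomial) term with a trigonometric term, which increases the model's flexibility and expressiveness. PyTorch has no built-in KAN module, so such an activation has to be defined by hand and then plugged into, for example, a sequence-to-sequence (seq2seq) model with a 1D CNN encoder. A simplified example:

```python
import torch
import torch.nn as nn


class KANActivation(nn.Module):
    """Learnable activation a * x + sin(b * x).

    A simplified stand-in only: the original KAN uses B-spline-based functions on every edge.
    """

    def __init__(self):
        super().__init__()
        # Learnable parameters of the activation
        self.kan_poly = nn.Parameter(torch.randn(1))  # coefficient of the linear term
        self.kan_trig = nn.Parameter(torch.randn(1))  # frequency of the sine term

    def forward(self, x):
        return self.kan_poly * x + torch.sin(self.kan_trig * x)


# 1D CNN encoder
class CNNEncoder(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, padding):
        super().__init__()
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, padding=padding)
        self.activation = KANActivation()

    def forward(self, x):
        # x: (batch, in_channels, seq_len) -> (batch, out_channels, seq_len)
        return self.activation(self.conv(x))


# LSTM decoder
class LSTMDecoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_layers, dropout, output_dim):
        super().__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers,
                            dropout=dropout, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)  # output layer

    def forward(self, encoder_outputs):
        # Conv1d yields (batch, channels, seq_len); the batch_first LSTM expects
        # (batch, seq_len, features), so swap the last two dimensions.
        output, _ = self.lstm(encoder_outputs.transpose(1, 2))
        return self.fc(output)  # (batch, seq_len, output_dim)


# Full seq2seq model
class Seq2SeqModel(nn.Module):
    def __init__(self, encoder_params, decoder_params):
        super().__init__()
        self.encoder = CNNEncoder(**encoder_params)
        self.decoder = LSTMDecoder(**decoder_params)

    def forward(self, src):
        # Real use needs padding/truncation and, for autoregressive decoding,
        # feeding back the previous step's prediction (e.g. teacher forcing).
        encoder_outputs = self.encoder(src)
        return self.decoder(encoder_outputs)


# Example parameters
params_encoder = {'in_channels': 1, 'out_channels': 64, 'kernel_size': 3, 'padding': 1}
params_decoder = {'input_dim': 64, 'hidden_dim': 128, 'num_layers': 2,
                  'dropout': 0.5, 'output_dim': 1}
model = Seq2SeqModel(params_encoder, params_decoder)
pred = model(torch.randn(8, 1, 24))  # 8 sequences of length 24
print(pred.shape)  # torch.Size([8, 24, 1])
```

This is only a basic skeleton. A real application would also need an attention mechanism (e.g. self-attention), possibly bidirectional recurrence, a training loop, and a loss function. The activation's parameters also have to be fitted to your data. For a faithful KAN implementation and its applications, consult the original papers or the official `pykan` library.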