Transformer+KAN系列时间序列预测代码

前段时间,来自 MIT 等机构的研究者提出了一种非常有潜力的替代方法 ——KAN。该方法在准确性和可解释性方面表现优于 MLP。而且,它能以非常少的参数量胜过以更大参数量运行的 MLP。

KAN的发布,引起了AI社区大量的关注与讨论,同时也伴随很大的争议。
而此类研究,又有了新的进展。

最近,来自新加坡国立大学的研究者提出了 Kolmogorov–Arnold Transformer(KAT),用 Kolmogorov-Arnold Network(KAN)层取代 MLP 层,以增强模型的表达能力和性能。

c2f9e6ac7f317eaf3ed75a035c2a461d.png

  • 论文标题:Kolmogorov–Arnold Transformer

  • 论文地址:https://arxiv.org/pdf/2409.10594

  • 项目地址:https://github.com/Adamdad/kat

KAN 原论文第一作者 Ziming Liu 也转发点赞了这项新研究。

dd4d6ba596e5e198dc6a3fdd9d37d2f2.png

将 KAN 集成到 Transformer 中并不是一件容易的事,尤其是在扩展时。具体来说,该研究确定了三个关键挑战:

(C1) 基函数。KAN 中使用的标准 B 样条(B-spline)函数并未针对现代硬件上的并行计算进行优化,导致推理速度较慢。

(C2) 参数和计算效率低下。KAN 需要每个输入输出对都有特定的函数,这使得计算量非常大。

(C3) 权重初始化。由于具有可学习的激活函数,KAN 中的权重初始化特别具有挑战性,这对于实现深度神经网络的收敛至关重要。

为了克服上述挑战,研究团队提出了三个关键解决方案:

(S1) 有理基础。该研究用有理函数替换 B 样条函数,以提高与现代 GPU 的兼容性。通过在 CUDA 中实现这一点,该研究实现了更快的计算。

(S2) Group KAN。通过一组神经元共享激活权重,以在不影响性能的情况下减少计算负载。

(S3) Variance-preserving 初始化。该研究仔细初始化激活权重,以确保跨层保持激活方差。

结合解决方案 S1-S3,该研究提出了一种新的 KAN 变体,称为 Group-Rational KAN (GR-KAN),以取代 Transformer 中的 MLP。

实验结果表明:GR-KAN 计算效率高、易于实现,并且可以无缝集成到视觉 transformer(ViT)中,取代 MLP 层以实现卓越的性能。此外,该研究的设计允许 KAT 从 ViT 模型加载预训练权重并继续训练以获得更好的结果。

该研究在一系列视觉任务中实证验证了 KAT,包括图像识别、目标检测和语义分割。结果表明,KAT 的性能优于传统的基于 MLP 的 transformer,在计算量相当的情况下实现了增强的性能。

66f0270001f4439d3151c665d5abc0f1.png

如图 1 所示,KAT-B 在 ImageNet-1K 上实现了 82.3% 的准确率,超过相同大小的 ViT 模型 3.1%。当使用 ViT 的预训练权重进行初始化时,准确率进一步提高到 82.7%。

不过,也有网友质疑道:「自从有论文比较了具有相同参数大小的 MLP 模型和 KAN 模型的性能后,我就对 KAN 持怀疑态度。可解释性似乎是唯一得到巨大提升的东西。」

24d779652acb920dfc10265964516841.png

对此,论文作者回应道:「的确,原始 KAN 在可解释性上做得很好,但不保证性能和效率。我们所做的就是修复这些 bug 并进行扩展。」

90a5efb358cdd51d53c17fca5a3e15dc.png

还有网友表示,这篇论文和其他人的想法一样,就是用 KAN 取代了 MLP,并质疑为什么作者在尝试一些已经很成熟和类似的东西,难道是在炒作 KAN?对此, 论文作者 Xingyi Yang 解释道,事实确实如此,但不是炒作,根据实验,简单地进行这种替换是行不通的,他们在努力将这个简单的想法变成可能的事情。

ff9fcbcad7febb59951e1587947d49f8.png

 Kolmogorov–Arnold Transformer (KAT)

作者表示,标准的 KAN 面临三大挑战,限制了其在大型深度神经网络中的应用。

它们分别是基函数的选择、冗余参数及其计算、初始化问题。这些设计选择使得原始版本的 KAN 是资源密集型的,难以应用于大规模模型。

本文对这些缺陷设计加以改进,以更好地适应现代 Transformer,从而允许用 KAN 替换 MLP 层。

源码地址及其详细讲解(免费)

https://space.bilibili.com/51422950?spm_id_from=333.1007.0.0

### Transformer在图像分类中的方法 Transformer架构最初设计用于自然语言处理任务,但其强大的建模能力使其逐渐被应用于计算机视觉领域。对于图像分类而言,一种有效的方式是将图像分割成多个小块(patch),并将这些小块视为序列数据输入到Transformer中[^1]。 具体来说,每张图片会被切分成固定大小的小方块(例如16×16像素)。之后,通过线性映射将每个图像块转换为一维向量作为Token表示形式。为了使模型能够理解不同位置的信息,在送入Transformer之前还需要加上可学习的位置编码。 #### 实现过程概述 以下是基于上述原理构建的一个简单版本的Vision Transformer (ViT) 的Python代码片段: ```python import torch from torch import nn, optim import torchvision.transforms as transforms from PIL import Image class PatchEmbedding(nn.Module): """ 将图像划分为若干个patches并嵌入 """ def __init__(self, img_size=224, patch_size=16, embed_dim=768): super().__init__() self.img_size = img_size self.patch_size = patch_size num_patches = (img_size // patch_size)**2 self.projection = nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size) def forward(self,x): x=self.projection(x).flatten(2).transpose(1,2) return x def get_positional_encoding(seq_len,d_model): pe=torch.zeros((seq_len,d_model)) position=torch.arange(0, seq_len).unsqueeze(1) div_term=(torch.exp(torch.arange(0., d_model, 2)*-(math.log(10000.)/d_model))) pe[:,::2]=torch.sin(position*div_term) pe[:,1::2]=torch.cos(position*div_term) return pe.unsqueeze(0) class VisionTransformer(nn.Module): def __init__(self,...): #省略其他参数定义 ... self.patch_embedding = PatchEmbedding(img_size=img_size, patch_size=patch_size, embed_dim=d_model) self.pos_encoder=get_positional_encoding(num_tokens,d_model) ... def forward(self, src): patches = self.patch_embedding(src) out = patches + self.pos_encoder[:,:patches.size(1),:] ... ``` 这段代码展示了如何创建一个基本的Patch Embedding层以及Position Encoding机制来准备输入给后续Transformer Encoder部分的数据流。实际应用时还可以加入Class Token以便于最终预测类别标签,并且可以堆叠多层Encoder以增强特征提取效果。 针对更先进的Swin Transformer结构,则采用了分层化的设计思路,即每一阶段都会缩小空间分辨率的同时增加通道数,从而更好地捕捉局部细节与全局上下文之间的关系。此外,还引入了移位窗口划分策略使得相邻两层之间存在重叠区域,进一步提升了性能表现[^2]。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值