纯加法Transformer!结合SNN和Transformer的Spike-driven Transformer

这里提出一种Spike-driven Transformer模型,首次将spike-driven计算范式融入Transformer。本文所提出的SDSA算子能耗比原始self-attention的能耗低87.2倍。所提出的Spike-driven Transformer在ImageNet-1K上取得了77.1%的SNN领域内SOTA结果。

论文地址:https://arxiv.org/abs/2307.01694

代码地址:https://github.com/BICLab/Spike-Driven-Transformer

受益于基于二进制脉冲信号的事件驱动(Spike-based event-driven,Spike-driven)计算特性,脉冲神经网络(Spiking Neural Network,SNN)提供了一种低能耗的深度学习选项 [1]。本文提出一种Spike-driven Transformer模型,首次将spike-driven计算范式融入Transformer。整个网络中只有稀疏加法运算。具体地,所提出的Spike-driven Transformer具有四个独特性质:

  • 事件驱动(Event-driven)。网络输入为0时,不会触发计算。
  • 二进制脉冲通信(Binary spike communication)。所有与脉冲张量相关的矩阵乘法都可以转化为稀疏加法。
  • 脉冲驱动自注意力(Spike-Driven Self-Attention,SDSA)算子。脉冲形式Q,K,V矩阵之间运算为掩码(mask)和加法。
  • 线性注意力(Linear attention)。SDSA算子的计算复杂度与token和channel都为线性关系。

本文所提出的SDSA算子能耗比原始self-attention的能耗低87.2倍。所提出的Spike-driven Transformer在ImageNet-1K上取得了77.1%的SNN领域内SOTA结果。

当前SNN模型的任务性能较低,难以满足实际任务场景中的精度要求。如何结合Transformer模型的高性能和SNN的低能耗,是目前SNN域内的研究热点。现有的spiking Transformer模型可以简单地被认为是异构计算模型,也就是将SNN中的脉冲神经元和Transformer模型中的一些计算单元(例如:dot-product, softmax, scale)相结合,既有乘加运算(Multiply-and-ACcumulate,MAC),也有加法运算(ACcumulate,AC)。虽然能保持较好的任务精度,但不能完全发挥出SNN的低能耗优势。

近期的一项工作,SpikFormer[2],展示了在spiking self-attention中,softmax操作是可以去掉的。然而,SpikFormer中保留了spiking self-attention中的scale操作。原因在于,脉冲形式Q,K,V矩阵之间运算会导致输出中会包含一些数值较大的整数,为避免梯度消失,SpikFormer保留了scale操作(乘法)。另一方面,SpikFormer采用Spike-Element-Wise(SEW)[3]的残差连接,也就是,在不同层的脉冲神经元输出之间建立shortcut。这导致与权重矩阵进行乘法操作的脉冲张量实际上是多bit脉冲(整数)。因此,严格来说,SpikFormer是一种整数驱动Transformer(Integer-driven Transformer),而不是脉冲驱动Transformer。

方法

本文提出了Spike-driven Transformer,如下图所示,以SpikFormer[2]中的模型为基础,做出两点关键改进:

  • 提出一种脉冲驱动自注意力(SDSA)算子。目前SNN领域中仅有Spike-driven Conv和spike-driven MLP两类算子。本文所提出的Spike-driven Self-attention算子,为SNN领域提供了一类新算子
  • 调整shortcut。将网络中的SEW全部调整为Membrane Shortcut(MS)[4,5],也就是在不同层的脉冲神经元膜电势之间建立残差连接。

Spike-Driven-Transformer_Self

SDSA算子。ANN中的原始自注意力(Vanilla Self-Attention,VSA)机制的表达式为: 

Spike-Driven-Transformer_矩阵乘法_02

总体来说,SDSA算子有两个特点:

Spike-Driven-Transformer_矩阵乘法_03

MS残差连接。目前SNN领域中一共有三种残差连接。一种是直接参考ResNet的Vanilla Shortcut [6],在不同层的膜电势和脉冲之间建立捷径;一种是SEW [3],在不同层的脉冲之间建立捷径;一种是MS [4],在不同层的膜电势之间建立捷径。MS连接之后会跟随一个脉冲神经元,这可以将膜电势之和转化为0/1,从而保证网络中所有脉冲张量与权重矩阵之间的乘法可以被转换为加法。因此,本文使用MS残差来保证spike-driven。 

SNN中的算子及其能耗评估

Spike-driven的核心是,与脉冲矩阵相关的乘法运算都可以被转换为稀疏加法。当SNN运行在神经形态芯片上时,spike-driven计算范式能够发挥出低能耗优势。

Spike-driven Conv和Spike-driven MLP。脉冲驱动计算有两层含义:事件驱动二进制脉冲通信。前者保证了输入为0时,不会触发计算;后者保证了有脉冲输入时,触发的计算为加法。当前SNN领域中,两类典型的算子是spike-driven Conv和spike-driven MLP。在进行矩阵乘法时,如果其中一个矩阵是脉冲形式,那么矩阵乘法可以通过寻址算法被转换为加法。

Spike-Driven-Transformer_矩阵乘法_04

结果

Spike-driven Transformer在ImageNet上的结果如下所示。本文取得了SNN域的SOTA结果。

Spike-Driven-Transformer_矩阵乘法_05

Spike-Driven-Transformer_矩阵乘法_06

Spike-Driven-Transformer_Self_07

全文到此结束,更多细节建议查看原文。本文所有代码和模型均已开源,欢迎关注我们的工作。