仅个人学习记录和个人拙见,有错误地方希望指导~
原文:Spike-driven Transformer
推荐阅读:纯加法Transformer!结合脉冲神经网络和Transformer的Spike-driven Transformer (NeurIPS 2023)
简单叙述
对transformer的整体框架转换如下图。因为我主要阅读的是Spike-Driven Self-Attention(SDSA)部分,所以对整体框架的解读就少了许多,下面只是简单介绍一下attention部分的qkv的改变。
SDSA的改进如下。也就是纯加法transformer的过程中对注意力机制的变化。图片中很形象,能够一目了然。主要是将矩阵乘积转换为了Hadamard multiply哈达玛积。还引用了点积等。在计算复杂度上也有了一定量的减少。
注意力部分代码解读
注意力部分实现在./module/ms_conv.py里,主要函数为MS_SSA_Conv()类。
主要看一下forward中对整个过程的改写。
(代码中dvs和hook部分忽略阅读,不影响整体框架。dvs是说输入是否是动态视频,hook是是否保存一些过程值)
前提:
x:输入
T:序列长度
B:batch长度,批大小
C:通道数
H:长度(行)
W:宽度(列)
x的形状是(T,B,C,H,W)
N:像素个数(N=H*W)
我们为了便于理解将T, B, C, H, W分别假设为1, 2, 4, 3, 3,num_heads=2
准备工作
1、首先复制一个原始x为identity,并初始化其他参数。
T, B, C, H, W = x.shape
identity = x
N = H * W
identity = x
2、将x变为脉冲信号
x = self.shortcut_lif(x)
3、新建一个变量x_for_qkv,存储x扁平化的结果。为了方便后续qkv的操作。
x_for_qkv = x.flatten(0, 1)
qkv初始化
4、得到q_conv_out(也就是注意力机制中的q,为了方便后续就将其叫为q)。首先将x_for_qkv进行一个二维卷积,然后再进行一个批归一化。并在同时reshape一下使得q的形状为(T, B, C, H, W)。并将q 变为脉冲形式。
在上述假设中此时的q为脉冲,并且形状为