仅个人学习记录和个人拙见,有错误地方希望指导~
原文:Spike-driven Transformer
推荐阅读:纯加法Transformer!结合脉冲神经网络和Transformer的Spike-driven Transformer (NeurIPS 2023)
简单叙述
对transformer的整体框架转换如下图。因为我主要阅读的是Spike-Driven Self-Attention(SDSA)部分,所以对整体框架的解读就少了许多,下面只是简单介绍一下attention部分的qkv的改变。
SDSA的改进如下。也就是纯加法transformer的过程中对注意力机制的变化。图片中很形象,能够一目了然。主要是将矩阵乘积转换为了Hadamard multiply哈达玛积。还引用了点积等。在计算复杂度上也有了一定量的减少。
注意力部分代码解读
注意力部分实现在./module/ms_conv.py里,主要函数为MS_SSA_Conv()类。
主要看一下forward中对整个过程的改写。
(代码中dvs和hook部分忽略阅读,不影响整体框架。dvs是说输入是否是动态视频,hook是是否保存一些过程值)
前提:
x:输入
T:序列长度
B:batch长度,批大小
C:通道数
H:长度(行)
W:宽度(列)
x的形状是(T,B,C,H,W)
N:像素个数(N=H*W)
我们为了便于理解将T, B, C, H, W分别假设为1, 2, 4, 3, 3,num_heads=2
准备工作
1、首先复制一个原始x为identity,并初始化其他参数。
T, B, C, H, W = x.shape
identity = x
N = H * W
identity = x
2、将x变为脉冲信号
x = self.shortcut_lif(x)
3、新建一个变量x_for_qkv,存储x扁平化的结果。为了方便后续qkv的操作。
x_for_qkv = x.flatten(0, 1)
qkv初始化
4、得到q_conv_out(也就是注意力机制中的q,为了方便后续就将其叫为q)。首先将x_for_qkv进行一个二维卷积,然后再进行一个批归一化。并在同时reshape一下使得q的形状为(T, B, C, H, W)。并将q 变为脉冲形式。
在上述假设中此时的q为脉冲,并且形状为(1, 2, 4, 3, 3)
q_conv_out = self.q_conv(x_for_qkv)。#卷积
q_conv_out = self.q_bn(q_conv_out).reshape(T, B, C, H, W).contiguous()。#批归一化并reshape
q_conv_out = self.q_lif(q_conv_out) #变脉冲
5、为了便于后续处理,还需要将形状改变一下。经过下面的操作后变为(T, B, num_heads, N, C/num_heads)。
现在的q的形状变为了(1, 2, 2, 9, 2)
q = (
q_conv_out.flatten(3)
.transpose(-1, -2)
.reshape(T, B, N, self.num_heads, C // self.num_heads)
.permute(0, 1, 3, 2, 4)
.contiguous()
)
6、k和v执行相同的代码。
卷积-> 批归一化-> reshape-> 变为脉冲 ->reshape
此时的k和v形状也全部变为了(1, 2, 2, 9, 2)
k_conv_out = self.k_conv(x_for_qkv)
k_conv_out = self.k_bn(k_conv_out).reshape(T, B, C, H, W).contiguous()
k_conv_out = self.k_lif(k_conv_out)
k = (
k_conv_out.flatten(3)
.transpose(-1, -2)
.reshape(T, B, N, self.num_heads, C // self.num_heads)
.permute(0, 1, 3, 2, 4)
.contiguous()
)
v_conv_out = self.v_conv(x_for_qkv)
v_conv_out = self.v_bn(v_conv_out).reshape(T, B, C, H, W).contiguous()
v_conv_out = self.v_lif(v_conv_out)
v = (
v_conv_out.flatten(3)
.transpose(-1, -2)
.reshape(T, B, N, self.num_heads, C // self.num_heads)
.permute(0, 1, 3, 2, 4)
.contiguous()
) # T B head N C//h
注意力计算
7、k和v进行逐元素相乘。然后再按照倒数第二维度进行求和。得到的结果再通过LIF进行脉冲转化。
kv = k.mul(v)。 #逐元素相乘,形状(1,2,2,9,2)
kv = kv.sum(dim=-2, keepdim=True) #按照N所在维度求和,形状变为(1,2,2,1,2)
kv = self.talking_heads_lif(kv)。 #变脉冲
8、q和kv(即由7得到的结果)进行逐元素相乘赋值给x。也就是进行哈达玛积。此时的x还是01脉冲。并且此时的x形状为(1, 2, 2, 9, 2)
x = q.mul(kv)
9、将x转换形状变为(T, B, C, H, W)
x = x.transpose(3, 4).reshape(T, B, C, H, W).contiguous()
10、对x进行展平-> 卷积-> 批归一化-> reshape,再将x加上原始x,也就是加上identify
x = (
self.proj_bn(self.proj_conv(x.flatten(0, 1)))
.reshape(T, B, C, H, W)
.contiguous()
)
x = x + identity