加法Transformer!Spike-driven Transformer部分代码解读 (NeurIPS 2023)

纯加法Transformer!脉冲神经网络和Transformer的Spike-driven Transformer部分代码解读

仅个人学习记录和个人拙见,有错误地方希望指导~

原文:Spike-driven Transformer
推荐阅读:纯加法Transformer!结合脉冲神经网络和Transformer的Spike-driven Transformer (NeurIPS 2023)

简单叙述

对transformer的整体框架转换如下图。因为我主要阅读的是Spike-Driven Self-Attention(SDSA)部分,所以对整体框架的解读就少了许多,下面只是简单介绍一下attention部分的qkv的改变。
请添加图片描述

SDSA的改进如下。也就是纯加法transformer的过程中对注意力机制的变化。图片中很形象,能够一目了然。主要是将矩阵乘积转换为了Hadamard multiply哈达玛积。还引用了点积等。在计算复杂度上也有了一定量的减少。
在这里插入图片描述

注意力部分代码解读

注意力部分实现在./module/ms_conv.py里,主要函数为MS_SSA_Conv()类。
主要看一下forward中对整个过程的改写。
(代码中dvs和hook部分忽略阅读,不影响整体框架。dvs是说输入是否是动态视频,hook是是否保存一些过程值)

在这里插入图片描述

前提
x:输入
T:序列长度
B:batch长度,批大小
C:通道数
H:长度(行)
W:宽度(列)
x的形状是(T,B,C,H,W)
N:像素个数(N=H*W)

我们为了便于理解将T, B, C, H, W分别假设为1, 2, 4, 3, 3,num_heads=2

准备工作

1、首先复制一个原始x为identity,并初始化其他参数。

T, B, C, H, W = x.shape
identity = x
N = H * W
identity = x

2、将x变为脉冲信号

x = self.shortcut_lif(x)

3、新建一个变量x_for_qkv,存储x扁平化的结果。为了方便后续qkv的操作。

x_for_qkv = x.flatten(0, 1)

qkv初始化

4、得到q_conv_out(也就是注意力机制中的q,为了方便后续就将其叫为q)。首先将x_for_qkv进行一个二维卷积,然后再进行一个批归一化。并在同时reshape一下使得q的形状为(T, B, C, H, W)。并将q 变为脉冲形式。
在上述假设中此时的q为脉冲,并且形状为(1, 2, 4, 3, 3)

q_conv_out = self.q_conv(x_for_qkv)#卷积
q_conv_out = self.q_bn(q_conv_out).reshape(T, B, C, H, W).contiguous()#批归一化并reshape
q_conv_out = self.q_lif(q_conv_out)  #变脉冲

5、为了便于后续处理,还需要将形状改变一下。经过下面的操作后变为(T, B, num_heads, N, C/num_heads)。
现在的q的形状变为了(1, 2, 2, 9, 2)

q = (
   q_conv_out.flatten(3)
    .transpose(-1, -2)
    .reshape(T, B, N, self.num_heads, C // self.num_heads)
    .permute(0, 1, 3, 2, 4)
    .contiguous()
)

6、k和v执行相同的代码。
卷积-> 批归一化-> reshape-> 变为脉冲 ->reshape
此时的k和v形状也全部变为了(1, 2, 2, 9, 2)

k_conv_out = self.k_conv(x_for_qkv)
k_conv_out = self.k_bn(k_conv_out).reshape(T, B, C, H, W).contiguous()
k_conv_out = self.k_lif(k_conv_out)
k = (
   k_conv_out.flatten(3)
   .transpose(-1, -2)
   .reshape(T, B, N, self.num_heads, C // self.num_heads)
   .permute(0, 1, 3, 2, 4)
   .contiguous()
)


v_conv_out = self.v_conv(x_for_qkv)
v_conv_out = self.v_bn(v_conv_out).reshape(T, B, C, H, W).contiguous()
v_conv_out = self.v_lif(v_conv_out)
v = (
   v_conv_out.flatten(3)
   .transpose(-1, -2)
   .reshape(T, B, N, self.num_heads, C // self.num_heads)
   .permute(0, 1, 3, 2, 4)
   .contiguous()
)  # T B head N C//h

注意力计算

7、k和v进行逐元素相乘。然后再按照倒数第二维度进行求和。得到的结果再通过LIF进行脉冲转化。

kv = k.mul(v)#逐元素相乘,形状(1,2,2,9,2)
kv = kv.sum(dim=-2, keepdim=True)  #按照N所在维度求和,形状变为(1,2,2,1,2)
kv = self.talking_heads_lif(kv)#变脉冲

8、q和kv(即由7得到的结果)进行逐元素相乘赋值给x。也就是进行哈达玛积。此时的x还是01脉冲。并且此时的x形状为(1, 2, 2, 9, 2)

x = q.mul(kv)

9、将x转换形状变为(T, B, C, H, W)

x = x.transpose(3, 4).reshape(T, B, C, H, W).contiguous()

10、对x进行展平-> 卷积-> 批归一化-> reshape,再将x加上原始x,也就是加上identify

x = (
self.proj_bn(self.proj_conv(x.flatten(0, 1)))
    .reshape(T, B, C, H, W)
    .contiguous()
)

x = x + identity
  • 35
    点赞
  • 24
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值