Talking-Heads Attention

paper:Talking-Heads Attention

在CaiT这篇文章中,作用采用了talking-heads attention,这里做一下解释。

在原始multi-head self-attention中,各个head的计算是独立进行的,多个head的输出最后concat到一起,然后再经过一个线性变换得到最终的输出。

本文提出了在softmax操作的前后引入跨注意力头维度的线性变换,从而使每个self-attention函数依赖于所有的key和query。

下面分别是timm中普通Attention和TalkingHeadAttention的实现

# class Attention
def forward(self, x: torch.Tensor) -> torch.Tensor:  # (1,197,192)
    B, N, C = x.shape
    qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
    # (1,197,576)->(1,197,3,3,64)->(3,1,3,197,64), (3, batch_size, num_heads, seq_len, head_dim), 3表示qkv
    q, k, v = qkv.unbind(0)  # (1,3,197,64)
    q, k = self.q_norm(q), self.k_norm(k)

    if self.fused_attn:  # False
        x = F.scaled_dot_product_attention(
            q, k, v,
            dropout_p=self.attn_drop.p if self.training else 0.,
        )
    else:
        # attn=softmax(qk)
        q = q * self.scale
        attn = q @ k.transpose(-2, -1)  # (1,3,197,197)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        x = attn @ v  # (1,3,197,64)

    x = x.transpose(1, 2).reshape(B, N, C)  # (1,197,3,64)->(1,197,192)
    x = self.proj(x)  # (1,197,192)
    x = self.proj_drop(x)
    return x

# class TalkingHeadAttn
def forward(self, x):
    B, N, C = x.shape  # (1,196,384)
    qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)  # (1,196,1152)->(1,196,3,8,48)->(3,1,8,196,48)
    q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]  # (1,8,196,48)

    attn = q @ k.transpose(-2, -1)  # (1,8,196,196)

    attn = self.proj_l(attn.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)  # (1,196,196,8)->(1,196,196,8)->(1,8,196,196)

    attn = attn.softmax(dim=-1)  # (1,8,196,196)

    attn = self.proj_w(attn.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)  # (1,196,196,8)->(1,196,8,8)->(1,8,196,196)
    attn = self.attn_drop(attn)

    x = (attn @ v).transpose(1, 2).reshape(B, N, C)  # (1,8,196,48)->(1,196,8,48)->(1,196,384)
    x = self.proj(x)  # (1,196,384)
    x = self.proj_drop(x)
    return x

从下图的对比看的更加清楚,左边是普通的attention,右边是talking-heads attention。左边的输入shape为(1, 197, 192),其中197=196+1是添加了class token,192是特征维度。右边的输入shape为(1, 196, 384),特征维度为384。左边num_heads=3,右边num_heads=8。因为左边的代码来自vision transformer,右边的代码来自CaiT,选择的具体模型variant不同,所以特征维度和head数量也不一样,但不影响。

可以看到,TalkingHeadAttention在计算softmax前后分别引入了一个线性变换self.proj_lself.proj_w,定义分别为self.proj_l = nn.Linear(num_heads, num_heads)self.proj_w = nn.Linear(num_heads, num_heads)。在线性变换前先对输入进行维度变换通过.permute(0, 2, 3 ,1)将num_head维度放到最后,因此线性变换是针对num_head维度的,从而实现跨head的交互,最后再permute回去。

  • 4
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

00000cj

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值