Talking-Heads Attention

本文探讨了谷歌提出的Talking-HeadsAttention,它通过线性映射增强多头注意力间的交互,提升模型性能。介绍了其原理、基本实现方式,并提供了相关实现代码链接。对比了与标准Multi-HeadAttention的区别,展示了在语言处理中的应用潜力。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

1. Multi-Head Attention

当前最流行的Attention机制当属 Scaled-Dot Attention (源于 Attention Is All You Need) ,即:

在这里插入图片描述

基于上述 Scaled-Dot Attention 下标准的 Multi-Head Attention 如下所示:

在这里插入图片描述

2. Talking-Heads Attention

近日,来自 Google 的研究团队提出一种「交谈注意力机制」(Talking-Heads Attention),在 softmax 操作前后引入对多头注意力之间的线性映射,以此增加多个注意力机制间的信息交流。这样的操作虽然增加了模型的计算复杂度,却能够在多项语言处理问题上取得更好的效果。

2.1 基本原理

当前的Multi-Head Attention每个head的运算是相互孤立的,而通过将它们联系(Talking)起来,则可以得到更强的Attention设计

在这里插入图片描述
如上图,就是将多头注意力用一个参数矩阵重新融合成多个混合注意力。每个新的得到的混合注意力都融合了原先的各head注意力。
注:
1、这里省略了缩放因子 {d_k}^1/2
2、新生成的多个混合注意力可以多于原先的h

2.2 具体实现

【参考博客】:

### 多头注意力机制概念 多头注意力机制是注意力机制的一种扩展形式,在深度学习领域被广泛应用。通过将输入的查询(Q)、键(K)和值(V)分成多个部分,即所谓的“头部”,从而允许模型关注来自不同表示子空间的信息[^1]。 ### 工作原理 具体来说,对于给定的一组输入序列,这些输入会被线性变换映射到不同的低维空间形成各自的 Q、K 和 V 向量。接着,每个这样的三元组都会经历标准缩放点积注意计算过程: \[ \text{Attention}(Q,K,V)=\text{softmax}\left(\frac{Q K^{T}}{\sqrt {d_k}}\right)V \] 其中 \( d_k \) 是键向量维度大小。此操作会在每一个头上独立执行多次,最终得到的结果会再次拼接起来并通过另一个线性转换层来获得最后输出[^3]。 ### 实现方法 下面是一个简单的 Python 代码片段展示如何构建一个多头自注意力模块: ```python import torch.nn as nn import math class MultiHeadAttention(nn.Module): def __init__(self, embed_size, num_heads): super(MultiHeadAttention, self).__init__() self.embed_size = embed_size self.num_heads = num_heads assert (embed_size % num_heads == 0), "Embed size needs to be divisible by heads" self.depth = embed_size // num_heads self.w_q = nn.Linear(embed_size, embed_size) self.w_k = nn.Linear(embed_size, embed_size) self.w_v = nn.Linear(embed_size, embed_size) self.fc_out = nn.Linear(embed_size, embed_size) def forward(self, queries, keys, values, mask=None): N = queries.shape[0] query_len, key_len, value_len = queries.shape[1], keys.shape[1], values.shape[1] # Split the embedding into `num_heads` different pieces. queries = self.split_heads(self.w_q(queries)) keys = self.split_heads(self.w_k(keys)) values = self.split_heads(self.w_v(values)) energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys]) if mask is not None: energy = energy.masked_fill(mask == 0, float("-1e20")) attention = torch.softmax(energy / (self.depth ** 0.5), dim=3) out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape( N, query_len, self.embed_size ) return self.fc_out(out) def split_heads(self, x): batch_size, seq_length, _ = x.size() return ( x.view(batch_size, seq_length, self.num_heads, self.depth) .transpose(1, 2) ) ```
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值