Several kinds of self-attention
- Multi-head attention
- Multi-query attention
- Grouped-query attention
Training and inference of large models use self-attention heavily, and the query, key, and value tensors of each attention layer have to be kept in GPU memory. In multi-head attention every head has its own query, key, and value projections, so the key/value cache grows with the number of heads and can take up a large amount of memory. Multi-query attention instead shares a single key and value head across all query heads, which greatly reduces memory use during training and inference, though it may hurt model quality. Grouped-query attention is the compromise between the two: the heads are split into groups, and each group shares one key and value head. When the number of groups equals the number of heads it reduces to multi-head attention, and when there is a single group it reduces to multi-query attention.
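To make the memory difference concrete, here is a rough per-token, per-layer KV-cache estimate; the hidden size, head count, group count, and fp16 storage below are illustrative assumptions rather than numbers from any particular model.

# Rough per-token KV-cache footprint of one attention layer, in bytes.
# All numbers are illustrative assumptions (fp16 storage, d_model = 4096, 32 query heads).
d_model, n_heads, bytes_per_value = 4096, 32, 2
head_dim = d_model // n_heads

def kv_cache_bytes_per_token(n_kv_heads: int) -> int:
    # one key vector and one value vector of size head_dim per KV head
    return 2 * n_kv_heads * head_dim * bytes_per_value

print("Multi-head attention   :", kv_cache_bytes_per_token(n_heads))  # 32 KV heads -> 16384 bytes
print("Grouped-query attention:", kv_cache_bytes_per_token(8))        #  8 groups   ->  4096 bytes
print("Multi-query attention  :", kv_cache_bytes_per_token(1))        #  1 KV head  ->   512 bytes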
Below is code for the three self-attention variants:
- Multi-head attention
import torch

# Incremental (decoding-time) multi-head attention
def MultiheadSelfAttentionIncremental():
    """
    d_model: model hidden size
    b: batch size
    h: number of heads
    d_k: key dimension per head
    d_v: value dimension per head
    """
    # hidden size 512, batch size 32, 8 heads, per-head key/value dim 512 // 8
    d_model, b, h, d_k, d_v = 512, 32, 8, (512 // 8), (512 // 8)
    m = 5  # number of tokens already cached
    # previously computed key and value caches; here they are randomly initialized
    # as if 5 tokens had already been processed
    prev_K = torch.rand(b, h, m, d_k)
    prev_V = torch.rand(b, h, m, d_v)
    X = torch.rand(b, d_model)  # current-step input used for the query
    M = torch.rand(b, d_model)  # current-step input used for the key and value
    # projection matrices for q, k, v and the output
    P_q = torch.rand(h, d_model, d_k)  # W_q
    P_k = torch.rand(h, d_model, d_k)  # W_k
    P_v = torch.rand(h, d_model, d_v)  # W_v
    P_o = torch.rand(h, d_model, d_v)  # W_o
    # project the input to a per-head query: (b, d) x (h, d, k) -> (b, h, k)
    q = torch.einsum("bd,hdk->bhk", X, P_q)
    # project the new token's key/value, then append them to the caches along the
    # token axis; prev_K is (batch, heads, cached tokens, d_k)
    new_K = torch.concat(
        [prev_K, torch.einsum("bd,hdk->bhk", M, P_k).unsqueeze(2)], axis=2
    )
    new_V = torch.concat(
        [prev_V, torch.einsum("bd,hdv->bhv", M, P_v).unsqueeze(2)], axis=2
    )
    # attention scores and softmax over the cached tokens
    logits = torch.einsum("bhk,bhmk->bhm", q, new_K)  # q . k per head
    weights = torch.softmax(logits, dim=-1)
    O = torch.einsum("bhm,bhmv->bhv", weights, new_V)
    y = torch.einsum("bhv,hdv->bd", O, P_o)  # combine heads through the output projection
    return y, new_K, new_V

if __name__ == "__main__":
    print(MultiheadSelfAttentionIncremental())
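As written, the function fabricates its own inputs; in a real decoder the returned caches new_K and new_V, of shape (b, h, m + 1, d_k) and (b, h, m + 1, d_v), would be passed back in as prev_K and prev_V on the next decoding step, while y has shape (b, d_model).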
- Multi-query attention
import torch

# Incremental (decoding-time) multi-query attention
def MultiquerySelfAttentionIncremental():
    # model hidden size, batch size, number of heads, key dim, value dim
    d, b, h, k, v = 512, 32, 8, (512 // 8), (512 // 8)
    m = 5  # the sequence already has 5 cached tokens
    # key/value caches for the 5 existing tokens; unlike multi-head attention,
    # multi-query attention keeps a single key and value per token no matter how
    # many heads there are, so the head dimension is gone
    prev_K = torch.rand(b, m, k)
    prev_V = torch.rand(b, m, v)
    X = torch.rand(b, d)  # current-step input used for the query
    M = torch.rand(b, d)  # current-step input used for the key and value
    # projection matrices for q, k, v and the output
    P_q = torch.rand(h, d, k)  # W_q (still one per head)
    P_k = torch.rand(d, k)     # W_k (shared across all heads)
    P_v = torch.rand(d, v)     # W_v (shared across all heads)
    P_o = torch.rand(h, d, v)  # W_o
    q = torch.einsum("bd,hdk->bhk", X, P_q)
    # append the new token's shared key/value to the caches
    K = torch.concat([prev_K, torch.einsum("bd,dk->bk", M, P_k).unsqueeze(1)], axis=1)
    V = torch.concat([prev_V, torch.einsum("bd,dv->bv", M, P_v).unsqueeze(1)], axis=1)
    logits = torch.einsum("bhk,bmk->bhm", q, K)  # every head attends over the same keys
    weights = torch.softmax(logits, dim=-1)
    O = torch.einsum("bhm,bmv->bhv", weights, V)
    y = torch.einsum("bhv,hdv->bd", O, P_o)
    return y, K, V

if __name__ == "__main__":
    print(MultiquerySelfAttentionIncremental())
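The only structural change relative to the multi-head version is that P_k and P_v lose their head dimension, so the cached K and V have shapes (b, m + 1, k) and (b, m + 1, v): the key/value cache is h times smaller per token.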
- Grouped-query attention
"""
在grouped-query attention中
当组数与头数相同时则为multi-head attention
当组数为1时则为multi-query attention
"""
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Equivalent to torch.repeat_interleave(x, dim=2, repeats=n_rep)."""
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:  # every query head already has its own KV head (MHA)
        return x
    return (  # MQA / GQA: replicate each KV head n_rep times along the head axis
        x[:, :, :, None, :]
        .expand(bs, slen, n_kv_heads, n_rep, head_dim)
        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
    )
class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        model_parallel_size = fs_init.get_model_parallel_world_size()
        self.n_local_heads = args.n_heads // model_parallel_size
        self.n_local_kv_heads = self.n_kv_heads // model_parallel_size  # KV heads on this model-parallel rank
        self.n_rep = self.n_local_heads // self.n_local_kv_heads  # query heads sharing each KV head (group size)
        self.head_dim = args.dim // args.n_heads
        self.wq = ColumnParallelLinear(
            args.dim,
            args.n_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wk = ColumnParallelLinear(
            args.dim,
            self.n_kv_heads * self.head_dim,  # only one key projection per group
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wv = ColumnParallelLinear(
            args.dim,
            self.n_kv_heads * self.head_dim,  # only one value projection per group
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wo = RowParallelLinear(
            args.n_heads * self.head_dim,
            args.dim,
            bias=False,
            input_is_parallel=True,
            init_method=lambda x: x,
        )
        # the KV caches are sized by n_local_kv_heads, not n_local_heads
        self.cache_k = torch.zeros(
            (
                args.max_batch_size,
                args.max_seq_len,
                self.n_local_kv_heads,
                self.head_dim,
            )
        ).cuda()
        self.cache_v = torch.zeros(
            (
                args.max_batch_size,
                args.max_seq_len,
                self.n_local_kv_heads,
                self.head_dim,
            )
        ).cuda()
    def forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        freqs_cis: torch.Tensor,
        mask: Optional[torch.Tensor],
    ):
        bsz, seqlen, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
        self.cache_k = self.cache_k.to(xq)
        self.cache_v = self.cache_v.to(xq)
        self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
        self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv
        keys = self.cache_k[:bsz, : start_pos + seqlen]
        values = self.cache_v[:bsz, : start_pos + seqlen]
        # repeat k/v heads if n_kv_heads < n_heads:
        # each group's single KV head is expanded to cover all of its query heads
        keys = repeat_kv(keys, self.n_rep)  # (bs, cache_len + seqlen, n_local_heads, head_dim)
        values = repeat_kv(values, self.n_rep)  # (bs, cache_len + seqlen, n_local_heads, head_dim)
        xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)
        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores + mask  # (bs, n_local_heads, seqlen, cache_len + seqlen)
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)
        output = torch.matmul(scores, values)  # (bs, n_local_heads, seqlen, head_dim)
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        return self.wo(output)
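For comparison with the two incremental functions above, here is the standalone einsum-style sketch of incremental grouped-query attention mentioned earlier. It is an illustrative sketch, not taken from the Llama code; the group count g is assumed to divide the number of heads h evenly, so g = h recovers multi-head attention and g = 1 recovers multi-query attention.

import torch

# Incremental (decoding-time) grouped-query attention, illustrative sketch
def GroupedquerySelfAttentionIncremental(g: int = 4):
    # model hidden size, batch size, number of query heads, key dim, value dim
    d, b, h, k, v = 512, 32, 8, (512 // 8), (512 // 8)
    assert h % g == 0  # the heads are split evenly into g KV groups
    m = 5  # tokens already cached
    prev_K = torch.rand(b, g, m, k)  # one key/value cache per group, not per head
    prev_V = torch.rand(b, g, m, v)
    X = torch.rand(b, d)  # current-step input used for the query
    M = torch.rand(b, d)  # current-step input used for the key and value
    P_q = torch.rand(h, d, k)  # W_q: still one projection per query head
    P_k = torch.rand(g, d, k)  # W_k: one key projection per group
    P_v = torch.rand(g, d, v)  # W_v: one value projection per group
    P_o = torch.rand(h, d, v)  # W_o
    # project queries, then arrange consecutive heads into groups: (b, g, h // g, k)
    q = torch.einsum("bd,hdk->bhk", X, P_q).reshape(b, g, h // g, k)
    # append the new token's per-group key/value to the caches
    K = torch.concat([prev_K, torch.einsum("bd,gdk->bgk", M, P_k).unsqueeze(2)], axis=2)
    V = torch.concat([prev_V, torch.einsum("bd,gdv->bgv", M, P_v).unsqueeze(2)], axis=2)
    # r indexes the query heads within a group; they all share the group's keys/values
    logits = torch.einsum("bgrk,bgmk->bgrm", q, K)
    weights = torch.softmax(logits, dim=-1)
    O = torch.einsum("bgrm,bgmv->bgrv", weights, V).reshape(b, h, v)
    y = torch.einsum("bhv,hdv->bd", O, P_o)
    return y, K, V

if __name__ == "__main__":
    print(GroupedquerySelfAttentionIncremental())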