头的数量不一致的原因
-
模型复杂性和计算资源:
- 查询(query)的头数多,表示模型可以捕获更多的上下文信息和关系,从而提高模型的表达能力。
- 键(key)和值(value)的头数少,可以减少计算量和内存占用,从而提高计算效率和减少资源消耗。
-
灵活性和控制:
- 通过设置不同的查询和键/值头数,可以灵活地控制模型的复杂性和计算负载。特别是在大模型中,不同的任务可能需要不同的注意力分配。
-
架构设计:
- 在一些特定的架构设计中,查询头数的增加可以提高模型的表现力,而键/值头数的减少可以保持模型的计算和存储成本在合理范围内。
实现细节
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
self.n_local_heads = args.n_heads // model_parallel_size
self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
self.n_rep = self.n_local_heads // self.n_local_kv_heads
#键/值头的数量小于查询头的数量时重复键/值头
self.head_dim = args.dim // args.n_heads
self.n_rep
是查询头相对于键/值头的重复次数,用于在前向传播中重复键和值,以匹配查询头的数量。
计算中的重复操作
在前向传播中,通过以下函数来重复键和值的头:
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
bs, slen, n_kv_heads, head_dim = x.shape
if n_rep == 1:
return x
else:
return ( x[:, :, :, None, :] .expand(bs, slen, n_kv_heads, n_rep, head_dim)
.reshape(bs, slen, n_kv_heads * n_rep, head_dim) )
该函数将键和值的头通过 expand
和 reshape
操作重复,以匹配查询头的数量。
实例中的应用
在 Attention
类的 forward
方法中,键和值通过 repeat_kv
函数进行重复:
keys = repeat_kv(keys, self.n_rep) values = repeat_kv(values, self.n_rep)
这确保了查询、键和值在计算注意力权重时具有相同的头数,从而可以进行矩阵乘法计算。
总结
通过允许查询和键/值具有不同的头数,模型可以在不显著增加计算和存储成本的情况下,提高其表示能力和灵活性。这种设计在大型模型和资源受限的情况下特别有用。