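QKV attention from the transformer architecture is so widely used that PyTorch ships a native (C++) implementation of it, exposed as torch.nn.functional.scaled_dot_product_attention and bound through torch._C._nn. It is faster than an equivalent implementation written in Python.
However, it does not return the attention weights, which is inconvenient when you want to log them to a board for visualization.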
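When the weights are needed for logging, one option is to compute attention in plain PyTorch next to the fused call. The following is a minimal sketch, not the library's own code: the helper name sdpa_with_weights is hypothetical, and it assumes no mask and no dropout.

```python
import math

import torch
import torch.nn.functional as F


def sdpa_with_weights(query, key, value, scale=None):
    """Unfused attention that also returns the weights (hypothetical helper; no mask, no dropout)."""
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    # Raw scores: (..., L, E) @ (..., E, S) -> (..., L, S)
    scores = (query @ key.transpose(-2, -1)) * scale_factor
    # Normalize over the key dimension so that each query row sums to 1
    attn_weight = torch.softmax(scores, dim=-1)
    return attn_weight @ value, attn_weight


torch.manual_seed(0)
q = torch.randn(2, 4, 5, 8)   # (N, heads, L, E)
k = torch.randn(2, 4, 7, 8)   # (N, heads, S, E)
v = torch.randn(2, 4, 7, 16)  # (N, heads, S, Ev)

out, attn_weight = sdpa_with_weights(q, k, v)
fused = F.scaled_dot_product_attention(q, k, v)

print(out.shape, attn_weight.shape)
print(torch.allclose(out, fused, atol=1e-5))
```

The trade-off is that the full (L, S) weight matrix is materialized, which is exactly what the fused kernels try to avoid, so a helper like this is probably best reserved for debugging or occasional logging steps.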
1. Method signature
scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> torch.Tensor
attn_mask and is_causal must not be set at the same time.
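Since the two options are mutually exclusive, it helps to see that is_causal=True can be reproduced with an explicit mask. The sketch below assumes square attention (L == S) and uses arbitrary tensor shapes chosen for illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
q = torch.randn(1, 2, 6, 8)  # (N, heads, L, E)
k = torch.randn(1, 2, 6, 8)  # (N, heads, S, E)
v = torch.randn(1, 2, 6, 8)  # (N, heads, S, Ev)

# Option 1: let the function build the causal mask itself
out_causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Option 2: build the same mask by hand; True means "may attend"
L, S = q.size(-2), k.size(-2)
causal_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
out_masked = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask)

print(torch.allclose(out_causal, out_masked, atol=1e-5))
```

If a padding mask and causal masking are both needed, one way is to combine them into a single attn_mask (for boolean masks, a logical AND) and leave is_causal at False.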
Args:
    query (Tensor): Query tensor; shape :math:`(N, ..., L, E)`.
    key (Tensor): Key tensor; shape :math:`(N, ..., S, E)`.
    value (Tensor): Value tensor; shape :math:`(N, ..., S, Ev)`.
    attn_mask (optional Tensor): Attention mask; shape must be broadcastable to the shape of attention weights,
        which is :math:`(N,..., L, S)`. Two types of masks are supported.
        A boolean mask where a value of True indicates that the element *should* take part in attention.
        A float mask of the same type as query, key, value that is added to the attention score.
    dropout_p (float): Dropout probability; if greater than 0.0, dropout is applied
    is_causal (bool): If set to true, the attention masking is a lower triangular matrix when the mask is a
        square matrix. The attention masking has the form of the upper left causal bias due to the alignment
        (see :class:`torch.nn.attention.bias.CausalBias`) when the mask is a non-square matrix.
        An error is thrown if both attn_mask and is_causal are set.
    scale (optional float, keyword-only): Scaling factor applied prior to softmax. If None, the default value is set
        to :math:`\frac{1}{\sqrt{E}}`.
Returns:
    output (Tensor): Attention output; shape :math:`(N, ..., L, Ev)`.
2. Walking through the source code in a debugger
def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> torch.Tensor:
    L, S = query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    attn_bias = torch.zeros(L, S