Bert注意力计算过程的维度变化分析
代码分析目标:了解Bert注意力计算过程中的维度变化
变量相关说明
变量 | batch_size | sequence_len | self.num_attention_heads | config.hidden_size |
---|---|---|---|---|
含义解释 | 单批次训练量 | 单序列长度 | 抽头个数 | Bert隐层大小 |
符号定义 | B | S | N | H |
取值举例 | B: 32 | S: 128 | N: 8 | H: 768 |
源码分析
# 注:取自 hugging face 团队实现的基于 pytorch 的 BERT 模型
class BERTSelfAttention(nn.Module):
# BERT 的 Self-Attention 类
def __init__(self, config):
# 初始化函数
super(BERTSelfAttention, self).__init__()
# H必须能被N整除(bert内H为768)
if config.hidden_size % config.num_attention_heads != 0:
raise ValueError(
"The hidden size (%d) is not a multiple of the number of attention "
"heads (%d)" % (config.hidden_size, config.num_a