一、计算过程理解
1、我们直接用torch实现一个 S e l f A t t e n t i o n Self Attention SelfAttention:
首先定义三个线性变换矩阵, q u e r y , k e y , v a l u e query, key, value query,key,value:
class BertSelfAttention(nn.Module):
self.query = nn.Linear(config.hidden_size, self.all_head_size) # 输入768, 输出768
self.key = nn.Linear(config.hidden_size, self.all_head_size) # 输入768, 输出768
self.value = nn.Linear(config.hidden_size, self.all_head_size) # 输入768, 输出768
注意,这里的 q u e r y , k e y , v a l u e query, key, value query,key,value只是一种操作(线性变换)的名称,实际的 Q / K / V Q/K/V Q/K/V是这三个线性操作的输出,三个变换的输入都是 768 768 768维,输出都是 768 768 768维,也就是三个线性变换矩阵的维度都为 ( 768 , 768 ) (768, 768) (768,768)。
2、假设三种操作的输入都是同一个矩阵,这里暂且定为长度为 6 6 6的句子,每个 t o k e n token token的特征维度是 768 768 768,那么输入就是 ( 6 , 768 ) (6, 768)