- transformer 结构
- transformer的掩码
- transformer中dropout出现的地方
- transformer中layer_norm出现的地方
- torch.nn.Transformer的mask使用
- 句子的处理流程
transformer 结构
transformer的掩码
注意力方式 | 掩码 |
---|---|
自注意力 | src_mask = src_att_mask + src_pad_mask |
自注意力 | tgt_mask = tgt_att_mask + tgt_pad_mask |
交叉注意力 | mem_mask = mem_att_mask + mem_pad_mask |
掩码类型 | 作用 |
---|---|
*_pad_mask | 为了让当前单词不注意到填充的单词. 一般为不规则锯齿形 |
*_att_mask | 为了让当前单词不注意到它“不该”注意到的单词。例如tgt_att_mask为上三角矩阵 |
掩码类型 | 初始形状 | 拓展后的形状 |
---|---|---|
src_mask | (batch_size, src_len) | (batch_size, head, src_len, src_len) |
tgt_mask | (batch_size, tgt_len) | (batch_size, head, tgt_len, tgt_len) |
mem_mask | (batch_size, src_len) | (batch_size, head, tgt_len, src_len) |
transformer中dropout出现的地方
-
位置嵌入之后。
-
注意力softmax之后。
-
多头注意力之后的线性变换后。
-
feed_forward 第二个线性层之后。
transformer中layer_norm出现的地方
-
残差块之前(或之后)
-
(最后一个Decoder之后)
torch.nn.Transformer的mask使用
Binary and float masks are supported. For a binary mask, a True
value indicates that the corresponding position is not allowed to attend. For a float mask, the mask values will be added to the attention weight.
If both attn_mask and key_padding_mask are supplied, their types should match.