GPT2模型源码阅读系列(三)一Block

在上一篇GPT2模型源码阅读系列(二)一GPT2Model中,调用Block的部分为

                outputs = block(
                    hidden_states,
                    layer_past=layer_past,
                    attention_mask=attention_mask,
                    head_mask=head_mask[i],
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                )
 

此时返回的outputs列表中的元素为:

  # <1> 第一个值为多头注意力聚合操作结果张量hidden_states输入前馈MLP层与残差连接
  之后得到的hidden_states张量,  形状为(batch_size, 1, n_state),
  all_head_size=n_state=nx=n_embd=768.
  
  # <2> 第二个值为上方的present张量, 其存储着past_key张量与这次迭代的
  key张量合并后的新key张量, 以及 past_value张量与这次迭代的
  value张量合并后的新value张量, 其形状为
  (2, batch_size, num_head, sql_len+1, head_features).
  
  # <3> 若output_attentions为True, 则第三个值为attn_outputs列表中的
  注意力分数张量w.
  
  # <4> 若此时进行了Cross Attention计算, 则第四个值为
  '交叉多头注意力计算结果列表cross_attn_outputs'中的
  交叉注意力分数张量cross_attention, 
  其形状为(batch_size, num_head, 1, enc_seq_len).

Block类中,主要结构为两个LayerNormalization层self.ln_1与self.ln_2、
一个Attention模块层self.attn、
一个前馈层self.mlp;
Attention层用来进行多头注意力聚合操作,前馈层用来进行全连接投影操作。

Cross_Attention 与 Masked_Multi_Self_Attention

若此时有编码器(encoder)中传过来的编码器隐藏状态encoder_hidden_states张量、encoder_attention_mask张量传入Block类中且config中的add_cross_attention超参数为True,则此时除了要进行GPT2中默认的Masked_Multi_Self_Attention计算之外,还需要和编码器(encoder)中传过来的编码器隐藏状态encoder_hidden_states张量进行Cross_Attention计算过程(self.crossattention)。

其中self.crossattention的Cross_Attention运算过程与self.attn的Masked_Multi_Self_Attention运算过程几乎相同, 其不同点在于self_attention将hidden拆成q,k,v三个变量,而cross_attention将hidden直接当作q,将encoder_hidden_states拆成k和v:

<1> self.attn的Masked_Multi_Self_Attention运算过程

self.attn的Masked_Multi_Self_Attention运算是将LayerNormalization之后的hidden_states张量通过Attention类中的 self.c_attn = Conv1D(3 * n_state, nx) 操作将hidden_states张量的形状由 (batch_size, 1, 768) 投影为 (batch_size, 1, 3 * 768),再将投影后的hidden_states张量在第三维度(dim=2)上拆分为三份,将其分别赋为query、key、value,其形状都为(batch_size, 1, 768),此时n_state = nx = num_head*head_features = 768。

之后经过Attention类中的split_heads()函数拆分注意力头且key、value张量分别与past_key、past_value张量合并之后:
query张量的形状变为(batch_size, num_head, 1, head_features),
key张量的形状变为(batch_size, num_head, head_features, sql_len+1),
value张量的形状变为(batch_size, num_head, sql_len+1, head_features).

之后便会利用得到的query、key、value进行多头注意力聚合操作,此时计算出的注意力分数张量w的形状为 (batch_size, num_head, 1, sql_len+1)。

<2> self.crossattention的Cross_Attention运算过程

self.crossattention的Cross_Attention运算过程则是将LayerNormalization之后的hidden_states张量通过Attention类中的 self.q_attn = Conv1D(n_state, nx) 操作将hidden_states张量的形状由(batch_size, 1, 768)投影为(batch_size, 1, 768),将此投影之后的hidden_states张量赋为query张量。

再将此时从编码器(encoder)中传过来的编码器隐藏状态 encoder_hidden_states 通过 Attention类中的 self.c_attn = Conv1D(2 * n_state, nx) 操作将encoder_hidden_states张量的形状由(batch_size, enc_seq_len, 768)投影为(batch_size, enc_seq_len, 2 * 768),将投影后的encoder_hidden_states张量在在第三维度(dim=2)上拆分为两份分别赋为key、value,其形状都为(batch_size, enc_seq_len, 768),此时n_state = nx = num_head*head_features = 768。

之后经过Attention类中的split_heads()函数拆分注意力头之后:
query张量的形状变为(batch_size, num_head, 1, head_features),
key张量的形状变为(batch_size, num_head, head_features, enc_seq_len),
value张量的形状变为(batch_size, num_head, enc_seq_len, head_features).

之后便会利用此时得到的query、key、value张量进行交叉多头注意力聚合操作,此时计算出的cross_attention张量形状为(batch_size, num_head, 1, enc_seq_len)。

class Block(nn.Module):
    def __init__(self, n_ctx, config, scale=False):
        super().__init__()
        # config对应的GPT2Config()类中, n_embd属性默认为768, 因此此处hidden_size即为768.
        hidden_size = config.n_embd
        # config对应的GPT2Config()类中, n_inner属性默认为None, 因此此处inner_dim一般都为4 * hidden_size.
        inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size

        self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        # 此处n_ctx即等于config对应的GPT2Config()类中的n_ctx属性, 其值为1024.
        self.attn = Attention(hidden_size, n_ctx, config, scale)
        self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)

        if config.add_cross_attention:
            self.crossattention = Attention(hidden_size, n_ctx, config, scale, is_cross_attention=True)
            self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.mlp = MLP(inner_dim, config)

    def forward(
        self,
        hidden_states,
        layer_past=None,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        use_cache=False,
        output_attentions=False,
    ):
        
        '''
        <1> 此时的隐藏状态hidden_states的形状为 (batch_size, 1, nx), 此时nx = n_state = n_embed = all_head_size = 768,
            即此时隐藏状态hidden_states的形状为(batch_size, 1, 768)<2> 此时layer_past为一个存储着past_key张量与past_value张量的大张量, 其
             形状为(2, batch_size, num_head, sql_len, head_features).
        <3> attention_mask张量为注意力遮罩张量, 其让填充特殊符[PAD]处的注意力分数极小,
             其embedding嵌入值基本不会在多头注意力聚合操作中被获取到.
        '''

        # 将此时输入的隐藏状态hidden_states先输入进LayerNormalization层进行层标准化计算后,
        # 再将标准化结果输入进'多头注意力计算层self.attn()'中进行多头注意力聚合操作计算.
        # 此时返回的attn_outputs列表中:
        # <1> 第一个值为多头注意力聚合操作结果张量a, 形状为(batch_size, 1, all_head_size), all_head_size=n_state=nx=n_embd=768.
        # <2> 第二个值为上方的present张量, 其存储着past_key张量与这次迭代的key张量合并后的新key张量, 以及
        #     past_value张量与这次迭代的value张量合并后的新value张量, 其形状为(2, batch_size, num_head, sql_len+1, head_features).
        # <3> 若output_attentions为True, 则第三个值为attn_outputs列表中的注意力分数张量w.
        attn_outputs = self.attn(
            self.ln_1(hidden_states),
            layer_past=layer_past,
            attention_mask=attention_mask,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )

        # 此时的attn_output张量为返回的attn_outputs列表中第一个值:
        # 多头注意力聚合操作结果张量a, 形状为(batch_size, 1, all_head_size), all_head_size=n_state=nx=n_embd=768.
        attn_output = attn_outputs[0]  # output_attn列表: a, present, (attentions)
        outputs = attn_outputs[1:]

        # residual connection, 进行残差连接.
        # 此时attn_output张量形状为(batch_size, 1, all_head_size), all_head_size=n_state=nx=n_embd=768.
        # hidden_states的形状为(batch_size, 1, 768).
        hidden_states = attn_output + hidden_states


        if encoder_hidden_states is not None:
            # add one self-attention block for cross-attention
            assert hasattr(
                self, "crossattention"
            ), f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"


            '''此时self.crossattention()的Cross_Attention运算过程与self.attn()的Attention运算过程几乎相同, 其不同点在于:

            <1> self.attn()的Attention运算是将LayerNormalization之后的hidden_states通过'self.c_attn = Conv1D(3 * n_state, nx)
            (165行代码)'将hidden_states的形状由(batch_size,1, 768)投影为(batch_size,1, 3 * 768), 再将投影后的hidden_states
            在第三维度(dim=2)上拆分为三份分别赋为query、key、value, 其形状都为(batch_size, 1, 768);
			此时n_state = nx = num_head*head_features = 768.
			
            之后经过split_heads()函数拆分注意力头且key、value张量分别与past_key、past_value张量合并之后:
            query张量的形状变为(batch_size, num_head, 1, head_features),
            key张量的形状变为(batch_size, num_head, head_features, sql_len+1),
            value张量的形状变为(batch_size, num_head, sql_len+1, head_features).

            <2> self.crossattention()的Cross_Attention运算过程则是将LayerNormalization之后的hidden_states通过
            'self.q_attn = Conv1D(n_state, nx)(第163行代码)'将hidden_states的形状由(batch_size,1, 768)投影为(batch_size,1, 768),
            将此投影之后的hidden_states赋值作为query张量;
            再将此时从编码器(encoder)中传过来的编码器隐藏状态encoder_hidden_states通过'self.c_attn = Conv1D(2 * n_state, nx)
            (162行代码)'将encoder_hidden_states的形状由(batch_size, enc_seq_len, 768)投影为(batch_size, enc_seq_len, 2 * 768),
            将投影后的encoder_hidden_states在在第三维度(dim=2)上拆分为两份分别赋为key、value,
            其形状都为(batch_size, enc_seq_len, 768); 此时n_state = nx = num_head*head_features = 768.
            
            之后经过split_heads()函数拆分注意力头之后:
            query张量的形状变为(batch_size, num_head, 1, head_features),
            key张量的形状变为(batch_size, num_head, head_features, enc_seq_len),
            value张量的形状变为(batch_size, num_head, enc_seq_len, head_features).
            此时计算出的cross_attention张量形状为(batch_size, num_head, 1, enc_seq_len).'''

            # 此时将上方的隐藏状态hidden_states(Attention运算结果+Attention运算前的hidden_states)先输入进LayerNormalization
            # 层进行层标准化计算后, 再将标准化结果输入进'交叉多头注意力计算层self.crossattention()'中与编码器传入的隐藏状态
            # encoder_hidden_states进行交叉多头注意力聚合操作计算.
            # 此时返回的cross_attn_outputs列表中:
            # <1> 第一个值为与编码器传入的隐藏状态encoder_hidden_states进行交叉多头注意力聚合操作的结果张量a,
            #     形状为(batch_size, 1, all_head_size), all_head_size=n_state=nx=n_embd=768。
            # <2> 第二个值仍为present张量, 但由于此时是做'交叉多头注意力计算self.crossattention()',此时输入进self.crossattention()
            #     函数的参数中不包含layer_past(来自past_key_values列表)的past_key与past_value张量, 因此此时的present为(None,),
            #     详细代码可见本脚本代码357, 因此此处用不到'交叉多头注意力计算结果列表cross_attn_outputs'中的present,
            #     将其舍弃(代码第528)。
            # <3> 若output_attentions为True, 则第三个值为: 交叉注意力分数张量w, 即cross attentions,
            #      cross_attention张量形状为(batch_size, num_head, 1, enc_seq_len).
            cross_attn_outputs = self.crossattention(
                self.ln_cross_attn(hidden_states),
                attention_mask=attention_mask,
                head_mask=head_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                output_attentions=output_attentions,
            )
            attn_output = cross_attn_outputs[0]
            # residual connection
            hidden_states = hidden_states + attn_output
            # cross_attn_outputs[2:] add cross attentions if we output attention weights,
            # 即将'交叉多头注意力计算结果列表cross_attn_outputs'中的交叉注意力分数张量cross_attention保存为此时的
            # outputs列表中的最后一个元素.
            outputs = outputs + cross_attn_outputs[2:]


        feed_forward_hidden_states = self.mlp(self.ln_2(hidden_states))
        # residual connection
        hidden_states = hidden_states + feed_forward_hidden_states

        outputs = [hidden_states] + outputs

        # 此时返回的outputs列表中的元素为:
        # <1> 第一个值为多头注意力聚合操作结果张量hidden_states输入前馈MLP层与残差连接之后得到的最终hidden_states张量,
        #     形状为(batch_size, 1, n_state), all_head_size=n_state=nx=n_embd=768.
        # <2> 第二个值为上方的present张量, 其存储着past_key张量与这次迭代的key张量合并后的新key张量, 以及
        #     past_value张量与这次迭代的value张量合并后的新value张量, 其形状为(2, batch_size, num_head, sql_len+1, head_features).
        # <3> 若output_attentions为True, 则第三个值为attn_outputs列表中的注意力分数张量w.
        # <4> 若此时进行了Cross Attention计算, 则第四个值为'交叉多头注意力计算结果列表cross_attn_outputs'中的
        #     交叉注意力分数张量cross_attention, 其形状为(batch_size, num_head, 1, enc_seq_len).
        return outputs  # hidden_states, present, (attentions, cross_attentions)


  • 2
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值