GPT2模型源码阅读系列(四)一Attention

终于到了重头戏Attention类,主要关注点为cross_attention, self_attention, split_head, layer_past

Attention类中的merge_heads()函数用来将多头注意力聚合操作结果张量a的注意力头维度进行合并,令多头注意力聚合操作结果张量a的形状由
(batch_size, num_head, 1, head_features)变为(batch_size, 1, all_head_size)
split_heads()函数用来对query张量、key张量与value张量进行注意力头拆分
而prune_heads()函数则可以用来删除一些注意力头。
而Attention类中最核心的函数为_attn()函数, _attn()函数即为用来对query、key、value三个张量进行多头注意力聚合操作的函数。

Cross_Attention与Masked_Multi_Self_Attention

而在Attention()类的forward()函数中一开始便会判断是否传入了编码器(encoder)中传过来的编码器隐藏状态encoder_hidden_states张量。

若此时传入了编码器隐藏状态encoder_hidden_states张量,则此时Attention()类中会进行 ‘Cross_Attention’ 的计算过程;

'self.q_attn = Conv1D(n_state, nx)(第168行代码)'将hidden_states的形状由(batch_size,1, 768)投影为(batch_size,1, 768),
           将此投影之后的hidden_states赋值作为query张量;
           再将此时从编码器(encoder)中传过来的编码器隐藏状态encoder_hidden_states通过'self.c_attn = Conv1D(2 * n_state, nx)
            (164行代码)'将encoder_hidden_states的形状由(batch_size, enc_seq_len, 768)投影为(batch_size, enc_seq_len, 2 * 768),
            将投影后的encoder_hidden_states在在第三维度(dim=2)上拆分为两份分别赋为key、value,
            其形状都为(batch_size, enc_seq_len, 768);  
            此时n_state = nx = num_head*head_features = 768.
            
            之后经过split_heads()函数拆分注意力头之后:
            query张量的形状变为(batch_size, num_head, 1, head_features),
            key张量的形状变为(batch_size, num_head, head_features, enc_seq_len),
            value张量的形状变为(batch_size, num_head, enc_seq_len, head_features).
            
            此时计算出的cross_attention张量形状为(batch_size, num_head, 1, enc_seq_len).'''

            query = self.q_attn(hidden_states)
            key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
            attention_mask = encoder_attention_mask

若此时未传入编码器隐藏状态encoder_hidden_states张量,则此时Attention()类中便会进行GPT2中默认的

‘Masked_Multi_Self_Attention’ 计算过程。

        else:
            '''此时隐藏状态hidden_states的形状为(batch_size, 1, 768), 将其输入进全连接层self.c_attn中后,Conv1D(3 * n_state, nx)操作(nx=n_state=768)便会将hidden_states的第三维度数由 768维 投影为 3 * 768,
            此时的hidden_states张量的形状为(batch_size, 1, 3 * 768), 最后将hidden_states张量在第三个维度(维度数3 * 768)上
            切分为三块, 将这切分出的三块各当成query, key, value张量, 则每个张量的形状都为(batch_size, 1, 768).
            此时n_state = nx = num_head*head_features = 768.
            
            之后经过split_heads()函数拆分注意力头且key、value张量分别与past_key、past_value张量合并之后:
            query张量的形状变为(batch_size, num_head, 1, head_features),
            key张量的形状变为(batch_size, num_head, head_features, sql_len+1),
            value张量的形状变为(batch_size, num_head, sql_len+1, head_features).'''
            query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)

split_head()

    def split_heads(self, x, k=False):
        # 此时new_x_shape为: (batch_size, sql_len, num_head, head_features)
        new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
        # 将输入的张量x(可能为query、key、value张量)变形为: (batch_size, sql_len, num_head, head_features).
        x = x.view(*new_x_shape)  # in Tensorflow implem: fct split_states

        # 若此时输入的张量为key张量,则需要将key张量再变形为(batch_size, num_head, head_features, sql_len).
        # 因为此时key张量需要以[query * key]的形式与query张量做内积运算, 因此key张量需要将head_features变换到第三维度,
        # 将sql_len变换到第四维度,这样[query * key]内积运算之后的注意力分数张量的形状才能符合(batch_size, num_head, sql_len, sql_len).
        if k:
            return x.permute(0, 2, 3, 1)  # (batch_size, num_head, head_features, sql_len)

        # 若此时输入的张量为query张量或value张量, 则将张量维度再变换为(batch_size, num_head, sql_len, head_features)即可,
        # 即将sql_len与num_head调换维度.
        else:
            return x.permute(0, 2, 1, 3)  # (batch_size, num_head, sql_len, head_features)

layer_past张量

此外,此时Attention类的forward()函数中也会判断是否传入了layer_past张量

此时若Attention类的forward()函数中传入了layer_past张量,则必为进行GPT2中默认的 ‘多头注意力聚合操作Masked_Multi_Self_Attention’ 计算过程,因为在进行 ‘交叉多头注意力聚合操作Cross_Attention’ 的计算过程时无需用到layer_past张量。

此时,根据layer_past张量中保存的past_key张量与past_value张量计算当前迭代中新的key张量与value张量的过程为:

<1> 当前迭代中新的key张量

通过layer_past[0].transpose(-2, -1)操作将past_key张量的形状变为
(batch_size, num_head, head_features, sql_len),

key张量的形状为(batch_size, num_head, head_features, 1),便可将past_key张量与key张量在最后一个维度(dim=-1)处进行合并,这样就将当前token的key部分加入了past_key的seq_len中,以方便模型在后面预测新的token,此时新的key张量的形状为: (batch_size, num_head, head_features, sql_len+1),new_seq_len为sql_len+1。

<2> 当前迭代中新的value张量

past_value张量不用变形,其形状为
(batch_size, num_head, sql_len, head_features),
而此时value张量的形状为(batch_size, num_head, 1, head_features),将past_value张量与value张量在倒数第二个维度(dim=-2)处进行合并,这样就将当前token的value部分加入了past_value的seq_len中,以方便模型在后面预测新的token,此时新的value张量的形状为: (batch_size, num_head, sql_len+1, head_features),new_seq_len为sql_len+1。

class Attention(nn.Module):
    def __init__(self, nx, n_ctx, config, scale=False, is_cross_attention=False):
        super().__init__()

        n_state = nx  # in Attention: n_state=768 (nx=n_embd)
        # [switch nx => n_state from Block to Attention to keep identical to TF implem]
        # 利用断言函数判断此时隐藏状态的维度数n_state除以注意力头数config.n_head之后是否能整除.
        assert n_state % config.n_head == 0

        # 下方的self.register_buffer()函数的操作相当于创建了两个Attention类中的self属性, 即为self.bias属性
        # 与self.masked_bias属性;
        # 其中self.bias属性为一个下三角矩阵(对角线下元素全为1, 对角线上元素全为0), 其形状为(1, 1, n_ctx, n_ctx),
        # 也即形状相当于(1, 1, 1024, 1024);
        # 而self.masked_bias属性则为一个极大的负数-1e4;
        self.register_buffer(
            "bias", torch.tril(torch.ones((n_ctx, n_ctx), dtype=torch.uint8)).view(1, 1, n_ctx, n_ctx)
        )
        self.register_buffer("masked_bias", torch.tensor(-1e4))


        self.n_head = config.n_head
        self.split_size = n_state
        self.scale = scale

        self.is_cross_attention = is_cross_attention
        if self.is_cross_attention:
            # self.c_attn = Conv1D(2 * n_state, nx)相当于全连接层, 其将输入张量的最后一个维度的维度数由nx(768)投影为
            # 2 * n_state(2*768), 此时n_state = nx = num_head*head_features = 768.
            self.c_attn = Conv1D(2 * n_state, nx)

            # self.q_attn = Conv1D(n_state, nx)相当于全连接层, 其将输入张量的最后一个维度的维度数由nx(768)投影为
            # n_state(768), 此时n_state = nx = num_head*head_features = 768.
            self.q_attn = Conv1D(n_state, nx)

        else:
            # self.c_attn = Conv1D(3 * n_state, nx)相当于全连接层, 其将输入张量的最后一个维度的维度数由nx(768)投影为
            # 2 * n_state(2*768), 此时n_state = nx = num_head*head_features = 768.
            self.c_attn = Conv1D(3 * n_state, nx)

        # 此处self.c_proj()Conv1D(n_state, nx)函数(all_head_size=n_state=nx=768), 相当于一个全连接层的作用,
        # 其将此时的多头注意力聚合操作结果张量a的最后一个维度all_head_size由n_state(768)的维度数投影为nx(768)的维度数.
        self.c_proj = Conv1D(n_state, nx)
        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)
        self.pruned_heads = set()


    # prune_heads()可结合 https://github.com/huggingface/transformers/issues/850 理解.
    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.n_head, self.split_size // self.n_head, self.pruned_heads
        )
        index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)])

        # Prune conv1d layers
        self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
        self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)

        # Update hyper params
        self.split_size = (self.split_size // self.n_head) * (self.n_head - len(heads))
        self.n_head = self.n_head - len(heads)
        self.pruned_heads = self.pruned_heads.union(heads)


    def merge_heads(self, x):
        # 此时x为: 利用计算得到的注意力分数张量对value张量进行注意力聚合后得到的注意力结果张量.
        # x的形状为(batch_size, num_head, sql_len, head_features).

        # 此时先将注意力结果张量x的形状变为(batch_size, sql_len, num_head, head_features)
        x = x.permute(0, 2, 1, 3).contiguous()
        # new_x_shape为(batch_size, sql_len, num_head*head_features) =》(batch_size, sql_len, all_head_size)
        new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)

        # 此时将注意力结果张量x的注意力头维度num_head与注意力特征维度head_features进行合并变为all_head_size维度,
        # 注意力结果张量x的形状变为(batch_size, sql_len, all_head_size).
        return x.view(*new_x_shape)  # in Tensorflow implem: fct merge_states, (batch_size, sql_len, all_head_size).


    def split_heads(self, x, k=False):
        # 此时new_x_shape为: (batch_size, sql_len, num_head, head_features)
        new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
        # 将输入的张量x(可能为query、key、value张量)变形为: (batch_size, sql_len, num_head, head_features).
        x = x.view(*new_x_shape)  # in Tensorflow implem: fct split_states

        # 若此时输入的张量为key张量,则需要将key张量再变形为(batch_size, num_head, head_features, sql_len).
        # 因为此时key张量需要以[query * key]的形式与query张量做内积运算, 因此key张量需要将head_features变换到第三维度,
        # 将sql_len变换到第四维度,这样[query * key]内积运算之后的注意力分数张量的形状才能符合(batch_size, num_head, sql_len, sql_len).
        if k:
            return x.permute(0, 2, 3, 1)  # (batch_size, num_head, head_features, sql_len)

        # 若此时输入的张量为query张量或value张量, 则将张量维度再变换为(batch_size, num_head, sql_len, head_features)即可,
        # 即将sql_len与num_head调换维度.
        else:
            return x.permute(0, 2, 1, 3)  # (batch_size, num_head, sql_len, head_features)


    def _attn(self, q, k, v, attention_mask=None, head_mask=None, output_attentions=False):
        
        '''
        此时query张量形状为: (batch_size, num_head, 1, head_features)
        key张量的形状为: (batch_size, num_head, head_features, sql_len+1)
        value张量的形状为: (batch_size, num_head, sql_len+1, head_features)

        此时key张量以[query * key]的形式与query张量做内积运算, key张量已在split_heads()操作与past_key合并操作中
        提前将head_features变换到第三维度, 将sql_len+1变换到第四维度,这样[query * key]内积运算之后的注意力分数张量w的
        形状才能符合(batch_size, num_head, 1, sql_len+1).
        '''
        w = torch.matmul(q, k)  # 注意力分数张量w: (batch_size, num_head, 1, sql_len+1)

        # 对注意力分数张量w中的值进行缩放(scaled), 缩放的除数为注意力头特征数head_features的开方值.
        if self.scale:
            w = w / (float(v.size(-1)) ** 0.5)

        # 此时nd与ns两个维度相当于1与seq_len+1
        nd, ns = w.size(-2), w.size(-1)

        # 此处的操作为利用torch.where(condition, x, y)函数,将注意力分数张量w在mask.bool()条件张量为True(1)的相同位置的值
        # 保留为w中的原值, 将在mask.bool()条件张量为True(0)的相同位置的值变为self.masked_bias(-1e4)的值.
        '''<1> GPT2Model第一次迭代时输入GPT2Model的forward()函数中的past_key_values参数为None, 此时nd与ns维度才会相等, 
        在nd与ns维度相等的情况下此操作的结果等价于让注意力分数张量w与attention_mask张量相加的结果。
        <2> 若为GPT2Mode第二次及之后的迭代时, nd与ns两个维度相当于1与seq_len+1, 此时对self.bias进行切片操作时, 
        ns - nd等于seq_len+1 - 1即结果为seq_len, 即此时切片操作相当于self.bias[:, :, seq_len : seq_len+1, :seq_len+1],
        此操作的意义在于对此次迭代中, 最新的token的注意力分数上添加GPT2中的下三角形式的注意力遮罩.'''
        if not self.is_cross_attention:
            # if only "normal" attention layer implements causal mask
            # 此时self.bias属性为一个下三角矩阵(对角线下元素全为1, 对角线上元素全为0), 其形状为(1, 1, n_ctx, n_ctx),
            # 也即形状相当于(1, 1, 1024, 1024);但此处对self.bias进行切片操作时, ns - nd等于seq_len+1 - 1即结果为seq_len,
            # 即此时切片操作相当于self.bias[:, :, seq_len : seq_len+1, :seq_len+1]'''此时mask张量(经过大张量self.bias切片获得)的形状为(1, 1, 1, seq_len + 1).'''
            mask = self.bias[:, :, ns - nd: ns, :ns]
            '''此操作的意义在于对此次迭代中, 最新的token的注意力分数上添加GPT2中的下三角形式注意力遮罩.'''
            w = torch.where(mask.bool(), w, self.masked_bias.to(w.dtype))

        # 让注意力分数张量w与attention_mask张量相加, 以达到让填充特殊符[PAD]处的注意力分数为一个很大的负值的目的,这样在下面将
        # 注意力分数张量w输入Softmax()层计算之后, 填充特殊符[PAD]处的注意力分数将会变为无限接近0的数, 以此让填充特殊符[PAD]
        # 处的注意力分数极小, 其embedding嵌入值基本不会在多头注意力聚合操作中被获取到.
        if attention_mask is not None:
            # Apply the attention mask
            w = w + attention_mask

        # 注意力分数张量w: (batch_size, num_head, 1, sql_len+1).
        # 将注意力分数张量w输入进Softmax()层中进行归一化计算, 计算得出最终的注意力分数,
        # 再将注意力分数张量w输入进Dropout层self.attn_dropout()中进行正则化操作, 防止过拟合.
        w = nn.Softmax(dim=-1)(w)
        w = self.attn_dropout(w)

        # Mask heads if we want to, 对注意力头num_head维度的mask操作.
        if head_mask is not None:
            w = w * head_mask

        # 多头注意力聚合操作: 注意力分数张量w与value张量进行内积
        # 注意力分数张量w形状: (batch_size, num_head, 1, sql_len+1)
        # value张量形状: (batch_size, num_head, sql_len+1, head_features)
        # 多头注意力聚合操作结果张量形状: (batch_size, num_head, 1, head_features), head_features=768.
        outputs = [torch.matmul(w, v)]
        # 若同时返回注意力分数张量w, 则将w张量添加入outputs列表中.
        if output_attentions:
            outputs.append(w)

        return outputs


    def forward(
        self,
        hidden_states,
        layer_past=None,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        use_cache=False,
        output_attentions=False,
    ):
        # <1> 此时的隐藏状态hidden_states的形状为 (batch_size, 1, nx), 此时nx = n_state = n_embed = head_features = 768,
        #     即此时隐藏状态hidden_states的形状为(batch_size, 1, 768)。
        # <2> 此时layer_past为一个存储着past_key张量与past_value张量的大张量, 其
        #     形状为(2, batch_size, num_head, sql_len, head_features).
        # <3> attention_mask张量为注意力遮罩张量, 其让填充特殊符[PAD]处的注意力分数极小,
        #     其embedding嵌入值基本不会在多头注意力聚合操作中被获取到.

        if encoder_hidden_states is not None:
            assert hasattr(
                self, "q_attn"
            ), "If class is used as cross attention, the weights `q_attn` have to be defined. " \
               "Please make sure to instantiate class with `Attention(..., is_cross_attention=True)`."

            '''self.crossattention()的Cross_Attention运算过程则是将LayerNormalization之后的hidden_states通过
            'self.q_attn = Conv1D(n_state, nx)(第168行代码)'将hidden_states的形状由(batch_size,1, 768)投影为(batch_size,1, 768),
            将此投影之后的hidden_states赋值作为query张量;
            再将此时从编码器(encoder)中传过来的编码器隐藏状态encoder_hidden_states通过'self.c_attn = Conv1D(2 * n_state, nx)
            (164行代码)'将encoder_hidden_states的形状由(batch_size, enc_seq_len, 768)投影为(batch_size, enc_seq_len, 2 * 768),
            将投影后的encoder_hidden_states在在第三维度(dim=2)上拆分为两份分别赋为key、value,
            其形状都为(batch_size, enc_seq_len, 768);  此时n_state = nx = num_head*head_features = 768.
            
            之后经过split_heads()函数拆分注意力头之后:
            query张量的形状变为(batch_size, num_head, 1, head_features),
            key张量的形状变为(batch_size, num_head, head_features, enc_seq_len),
            value张量的形状变为(batch_size, num_head, enc_seq_len, head_features).
            
            此时计算出的cross_attention张量形状为(batch_size, num_head, 1, enc_seq_len).'''

            query = self.q_attn(hidden_states)
            key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
            attention_mask = encoder_attention_mask

        else:
            '''此时隐藏状态hidden_states的形状为(batch_size, 1, 768), 将其输入进全连接层self.c_attn中后,Conv1D(3 * n_state, nx)操作(nx=n_state=768)便会将hidden_states的第三维度数由 768维 投影为 3 * 768,
            此时的hidden_states张量的形状为(batch_size, 1, 3 * 768), 最后将hidden_states张量在第三个维度(维度数3 * 768)上
            切分为三块, 将这切分出的三块各当成query, key, value张量, 则每个张量的形状都为(batch_size, 1, 768).
            此时n_state = nx = num_head*head_features = 768.
            
            之后经过split_heads()函数拆分注意力头且key、value张量分别与past_key、past_value张量合并之后:
            query张量的形状变为(batch_size, num_head, 1, head_features),
            key张量的形状变为(batch_size, num_head, head_features, sql_len+1),
            value张量的形状变为(batch_size, num_head, sql_len+1, head_features).'''
            query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)


        '''第一次迭代时query、key、value张量的seq_len维度处的维度数就为seq_len而不是1, 第二次之后seq_len维度的维度数皆为1.'''
        # 此时经过'注意力头拆分函数split_heads()'之后的query、key、value三个张量的形状分别为:
        # query: (batch_size, num_head, 1, head_features)
        # key: (batch_size, num_head, head_features, 1)
        # value: (batch_size, num_head, 1, head_features)
        query = self.split_heads(query)
        key = self.split_heads(key, k=True)
        value = self.split_heads(value)

        if layer_past is not None:
            '''第一次迭代时query、key、value张量的seq_len维度处的维度数就为seq_len而不是1, 第二次之后seq_len维度的维度数皆为1.'''
            '''<1> 本次迭代中新的key张量
            此时需要通过layer_past[0].transpose(-2, -1)操作将past_key张量的形状变为(batch_size, num_head, head_features, sql_len),
            而此时key张量的形状为(batch_size, num_head, head_features, 1), 这样在下方就方便将past_key张量与key张量在最后
            一个维度(dim=-1)处进行合并, 这样就将当前token的key部分加入了past_key的seq_len中, 以方便模型在后面预测新的token,
            此时新的key张量的形状为: (batch_size, num_head, head_features, sql_len+1), new_seq_len为sql_len+1<2> 本次迭代中新的value张量
            而此时past_value不用变形, 其形状为(batch_size, num_head, sql_len, head_features), 而此时value张量的形状为
            (batch_size, num_head, 1, head_features), 这样在下方就方便将past_value张量与value张量在倒数第二个
            维度(dim=-2)处进行合并, 这样就将当前token的value部分加入了past_value的seq_len中, 以方便模型在后面预测新的token,
            此时新的value张量的形状为: (batch_size, num_head, sql_len+1, head_features), new_seq_len为sql_len+1'''
            past_key, past_value = layer_past[0].transpose(-2, -1), layer_past[1]  # transpose back cf below
            key = torch.cat((past_key, key), dim=-1)
            value = torch.cat((past_value, value), dim=-2)

        # config对应的GPT2Config()类中的use_cache默认为True.但此时若为Cross_Attention运算过程, 则此时不会指定use_cache,
        # 而此时use_cache属性即为False(因为Attention类中use_cache属性默认为False, 除非指定config对应的GPT2Config()类
        # 中的use_cache属性其才会为True).
        if use_cache is True:
            # 若use_cache为True, 此时将key张量的最后一个维度与倒数第二个维度互换再与value张量进行stack合并,
            # 此时key.transpose(-2, -1)的形状为(batch_size, num_head, sql_len+1, head_features),
            # 此时torch.stack()操作后的present张量形状为(2, batch_size, num_head, sql_len+1, head_features)'''present张量形状: (2, batch_size, num_head, sql_len+1, head_features),
            即present张量是用来存储此次迭代中的key张量与上一次迭代中的past_key张量(layer_past[0])合并、
            本次迭代的value张量与上一次迭代中的past_value张量(layer_past[1])合并后所得的新的key张量与value张量的.'''
            present = torch.stack((key.transpose(-2, -1), value))  # transpose to have same shapes for stacking
        else:
            present = (None,)


        '''此时query张量形状为: (batch_size, num_head, 1, head_features)
        key张量的形状为: (batch_size, num_head, head_features, sql_len+1)
        value张量的形状为: (batch_size, num_head, sql_len+1, head_features)'''
        # 若output_attentions为True, 则self._attn()函数返回的attn_outputs列表中的第二个值为注意力分数张量w.
        attn_outputs = self._attn(query, key, value, attention_mask, head_mask, output_attentions)


        # 此时self._attn()函数返回的attn_outputs列表中的第一个元素为多头注意力聚合操作结果张量a,
        # a张量的形状为(batch_size, num_head, 1, head_features);
        # 若output_attentions为True, 则此时self._attn()函数返回的attn_outputs列表中的第二个元素为
        # 注意力分数张量w, 其形状为(batch_size, num_head, 1, seq_len + 1).
        a = attn_outputs[0]

        '''此时经过'多头注意力头合并函数self.merge_heads()'后的多头注意力聚合操作结果张量a的形状
        变为(batch_size, 1, all_head_size), 其中 all_head_size 等于 num_head * head_features, head_features=768.
        all_head_size维度的维度数为768,等于n_state,也等于nx, 即all_head_size=n_state=nx=768.'''
        a = self.merge_heads(a)

        # 此处self.c_proj()Conv1D(n_state, nx)函数(all_head_size=n_state=nx=768), 相当于一个全连接层的作用,
        # 其将此时的多头注意力聚合操作结果张量a的最后一个维度all_head_size由n_state(768)的维度数投影为nx(768)的维度数.
        a = self.c_proj(a)
        a = self.resid_dropout(a)  # 残差dropout层进行正则化操作, 防止过拟合.

        # 此时多头注意力聚合操作结果张量a的形状为(batch_size, 1, all_head_size),
        # 其中 all_head_size 等于 num_head * head_features;all_head_size维度的维度数为768,
        # 等于n_state,也等于nx, 即all_head_size=n_state=nx=n_embed=768.
        outputs = [a, present] + attn_outputs[1:]

        # 此时返回的outputs列表中:
        # <1> 第一个值为多头注意力聚合操作结果张量a, 形状为(batch_size, 1, all_head_size), all_head_size=n_state=nx=n_embd=768.
        # <2> 第二个值为上方的present张量, 其存储着past_key张量与这次迭代的key张量合并后的新key张量, 以及
        #     past_value张量与这次迭代的value张量合并后的新value张量, 其形状为(2, batch_size, num_head, sql_len+1, head_features).
        # <3> 若output_attentions为True, 则第三个值为attn_outputs列表中的注意力分数张量w,
        #     其形状为(batch_size, num_head, 1, seq_len + 1).
        return outputs  # a, present, (attentions)


  • 4
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值