MT5ForConditionalGeneration生成模型的推理细节,源码阅读

T5是Google提出的Seq2Seq结构的预训练语言模型,一经提出便登上了GLUE、SuperGLUE等各大NLP榜单第一,而它的升级版本mT5因为用了多国语言语料,在中文任务上可以开箱即用。

HuggingFace的Transformers包里的MT5ForConditionalGeneration,就是MT5生成模型的实现,我们只需调用它的model.generate()函数,就能用mT5模型进行推理和生成,简单易用。model.generate()函数背后的逻辑和内部的实现细节又是什么样的呢,本文带你一窥究竟。

Step 1. 关闭梯度

def generate(self)有装饰器@torch.no_grad(),不计算梯度和反向传播,节约计算资源。

@torch.no_grad()
    def generate(self, ...)

Step 2. Encoder推理

这里用 get_encoder 函数调用 self.encoder,self.encoder的参数都存储在 named_modules()[‘encoder’] 中。

encoder = self.get_encoder()
ModelOutput = encoder(input_ids, return_dict=True, **encoder_kwargs)

self.encoder是T5Stack的一个实例,这个实例的 forward 函数由字嵌入、数值变化、逐层编码、LN&dropout、存储Dict 共5个部分组成。接下来展开介绍:

Part 1. 字嵌入

这里把input_ids传入 self.embed_tokens 函数,调用Embedding(50000, 512)。Embedding(50000, 512) 的参数存储在 named_modules()[‘embed_tokens’]中。它是torch.nn.functional.embedding的一个实例,以lookup table方式快速取得input_ids对应的inputs_embeds。inputs_embeds维度是(16, 104, 512),也就是(bc, seq_len, emb_dim)。

inputs_embeds = self.embed_tokens(input_ids)

Part 2. 数值变化

这部分包括key_value、attention mask、head mask、hidden states的变化。

  • 初始化past_key_value为[None] * 8,也就是[None] * self.encoder层数
    # initialize past_key_values with `None` if past does not exist
    if past_key_values is None:
        past_key_values = [None] * len(self.block)
    
  • 把attention_mask传入get_extend_attention_mask函数,这个函数把attention_mask中值为1的元素变为0,把值为0的元素变为-10000.0,用于忽略encoder中的pad-masked元素。代码中的[:, None, None, :]用于把attention_mask的维度(bc, seq_len)更改为(bc, 1, 1, seq_len),例如(16, 1, 1, 104)。
    # Provided a padding mask of dimensions [batch_size, seq_length]
    # - if the model is a decoder, apply a causal mask in addition to the padding mask
    # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
    if not self.config.is_decoder:
    	extended_attention_mask = attention_mask[:, None, None, :]
    	
    	# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
    
  • 调用get_head_mask函数,初始化head_mask和encoder_head_mask为[None] * self.encoder层数
     if head_mask is not None:
         head_mask = self._convert_head_mask_to_5d(head_mask, num_hidden_layers)
         if is_attention_chunked is True:
             head_mask = head_mask.unsqueeze(-1)
     else:
         head_mask = [None] * num_hidden_layers
    
  • 把在Part 1. 字嵌入得到的 inputs_embeds 传入self.dropout,self.dropout是torch.nn.functional.dropout的一个实例,以p=0.1的概率丢弃神经元。
    hidden_states = F.dropout(inputs_embeds, self.p, self.training, self.inplace)
    

Part 3. 逐层编码

这部分进入Encoder的各个层,用for循环逐个访问self.encoder的8个层。我们把每层称为一个block。

for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)):
	layer_outputs = layer_module(hidden_states,
								 attention_mask=extended_attention_mask,
								 use_cache=False,
								 output_attentions=False,
								 ...)

这8个block的内部实现几乎一样,唯一差别是block_0(从0开始计数)比其他block多了一个relative_position_bias。进入到block内部,由Self-Attention Layer和Feed Forward layer两个模块组成。

# Apply Self-Attention layer
hidden_states = self.layer[0](hidden_states, ...)[0]

# Apply Feed Forward layer
hidden_states = self.layer[-1](hidden_states)

下面展开介绍Self-Attention Layer和Feed Forward layer。

Part 3.1 Self-Attention Layer

Self-Attention Layer 由 layernorm、SelfAttention、残差连接共三步组成。

  • 第一步,用self.layernorm调用T5LayerNorm,对传入hidden_states进行层归一化。
    normed_hidden_states = self.layer_norm(hidden_states)
    
    注意,mT5的layer_norm和常规layernorm不同,计算时省去了均值。
    def forward(self, hidden_states):
    	variance = hidden_states.pow(2).mean(-1, keepdim=True)
    	hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
    
    	return self.weight * hidden_states
    
  • 第二步,用self.SelfAttention调用T5Attention
    attention_output = self.SelfAttention(normed_hidden_states, mask=attention_mask, ...)
    
    T5Attention的内部计算和常规的self-attention计算类似,也就是q、k、v分别做线性变换(分别和参数矩阵w_q, w_k, w_v相乘),得到query_states、key_states、value_states,并且维度统一为(bc, n_head, seq_len, head_dim),然后query_states和key_states相乘得到score
    query_states = shape(self.q(hidden_states))  # (batch_size, n_heads, seq_length, dim_per_head)
    key_states = project(hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None)
    value_states = project(hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None)
    
    # compute scores
    scores = torch.matmul(query_states, key_states.transpose(3, 2))
    
    不同之处在于,mT5采用相对位置编码,所以会在self-attention计算时,把relative_position_bias和extended_attention_mask一同加在score上
    position_bias = self.compute_bias(real_seq_length, key_length)
    if extended_attention_mask is not None:
        position_bias = position_bias + extended_attention_mask # (batch_size, n_heads, seq_length, key_length)
    
    scores += position_bias
    
    extended_attention_mask在本文Part 2.数值变化部分讲解过;position_bias是如何计算的呢?这个计算过程是比较特别的。在计算当前token和目标token的attention值时,记录两个token的距离的绝对值,我们不直接使用这个距离值,而是根据距离值的大小进行一定程度的缩小,距离值越大缩小倍数越大,距离值越小缩小倍数越小。具体实现时,采用一种bucket的方法,bucket方法的实现在self._relative_position_bucket函数,代码比较长,这里就不贴了。
    def compute_bias(self, query_length, key_length):
        """ Compute binned relative position bias """
        context_position = torch.arange(query_length, dtype=torch.long)[:, None]
        memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
        relative_position = memory_position - context_position  # shape (query_length, key_length)
        relative_position_bucket = self._relative_position_bucket(
            relative_position,  # shape (query_length, key_length)
            bidirectional=(not self.is_decoder),
            num_buckets=self.relative_attention_num_buckets,
        )
        relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device)
        values = self.relative_attention_bias(relative_position_bucket)  # shape (query_length, key_length, num_heads)
        values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
        return values
    
    接下来,和常规self-attention一样,经过softmax得到归一化后的score,再dropout一下,再和value_states相乘,然后维度恢复为(bc, seq_len, n_head*head_dim),例如(16, 104, 384),最后做线性变换(和w_o相乘)。此外有一个细节,计算attention score时不除以dk
    attn_weights = F.softmax(scores.float(), dim=-1).type_as(scores)  # (batch_size, n_heads, seq_length, key_length)
    attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training)  # (batch_size, n_heads, seq_length, key_length)
    attn_output = unshape(torch.matmul(attn_weights, value_states))  # (batch_size, seq_length, dim)
    attn_output = self.o(attn_output)
    
  • 第三步,残差连接,把self-attention的输出传入self.dropout,再把dropout的输出与hidden_states直接相加。
    hidden_states = hidden_states + self.dropout(attention_output[0])
    
Part 3.2 Feed Forward Layer

Feed Forward Layer调用T5LayerFF,它由layernorm、DenseReluDense、残差连接共三步组成。

  • 第一步,layernorm,和Part 3.1一样,这里不再赘述。
    forwarded_states = self.layer_norm(hidden_states)
    
  • 第二步,DenseReluDense,它是T5DenseGatedGeluDense的一个实例,它的forward函数由以下3部分组成:
    • gelu_act(wi_0( ~ )) * wi_1( ~ ),"~"表示输入,wi_0和wi_1是两个线性变换层,gelu_act的计算公式是 g e l u _ a c t ( x ) = 1 2 ⋅ x ⋅ ( 1 + tanh ⁡ ( 2 π ⋅ x + 0.44715 ∗ x 3 ) ) gelu\_act(x)=\frac{1}{2}·x·(1+\tanh{(\sqrt{\frac{2}{\pi}}·x+0.44715*x^3)}) gelu_act(x)=21x(1+tanh(π2 x+0.44715x3))
    • dropout层,以p=0.1丢弃神经元,输出维度(16, 104, 1024),也就是(bc, seq_len, fc_dim)
    • wo线性变换层,输出维度(16, 104, 512),也就是(bc, seq_len, model_dim)
    def gelu_act(x):
    	return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
    
    hidden_gelu = self.gelu_act(self.wi_0(hidden_states))
    hidden_linear = self.wi_1(hidden_states)
    hidden_states = hidden_gelu * hidden_linear
    hidden_states = self.dropout(hidden_states)
    hidden_states = self.wo(hidden_states)
    
  • 第三步,残差连接,和Part 3.1一样,这里不再赘述。
    hidden_states = hidden_states + self.dropout(forwarded_states)
    

Part 4. LN&dropout

这部分由layernorm和dropout两部分组成。layernorm和dropout细节在前文介绍过,这里不再重复。

hidden_states = self.final_layer_norm(hidden_states)
hidden_states = self.dropout(hidden_states)

Part 5. 存储Dict

这里存储的是BaseModelOutputWithPastAndCrossAttentions的一个实例,是一个OrderedDict,包含encoder的输出。

return BaseModelOutputWithPastAndCrossAttentions(last_hidden_state=hidden_states,
									             past_key_values=present_key_value_states,
									             hidden_states=all_hidden_states,
									             attentions=all_attentions,
									             cross_attentions=all_cross_attentions,)

Step 3. Decoder推理

Decoder的input_ids是什么呢?分两种情况:

  • 默认情况下,Decoder的input_ids是[101]*batch_size,101是’[bos]'的token_id;
  • 喂给Decoder一个prompt,例如“上文提到了”,Decoder会接着这个输入进行生成。Decoder的输入是就是这5个字的token_id;
    # set input_ids as decoder_input_ids
    if "decoder_input_ids" in model_kwargs:
        input_ids = model_kwargs.pop("decoder_input_ids")
    else:
        input_ids = self._prepare_decoder_input_ids_for_generation(input_ids, 
    														       decoder_start_token_id=decoder_start_token_id, 
    														       bos_token_id=bos_token_id)
    

除了准备decoder的input_ids,还需要准备一个logits_processor,这里采用MinLengthLogitsProcessor(min_len, eos_id),其中min_len是0,eos_id是102,控制生成序列长度不得小于min_len。

# enforcing a min-length by setting EOS probability to 0.
class MinLengthLogitsProcessor(LogitsProcessor):
	def __call__(self, input_ids, scores):
	    cur_len = input_ids.shape[-1]
	    if cur_len < self.min_length:
	        scores[:, self.eos_token_id] = -float("inf")
	    return scores

准备完毕后,进入生成过程,本文采用greedy_search,greedy_search每个时间步只选择条件概率值最高的生成序列。self.greedy_search内部用while循环,一个字一个字地生成,直到生成的字是’<\s>'或者达到预设定的max_length值,例如40,便会停止生成。所以max_length值越大,推理耗时往往越长。while循环的内部,分为两部分:

第一部分,调用BaseModelOutput,获取Encoder的输出,前文讲到encoder的输出存储在Dict中,所以读字典取值即可。

if encoder_outputs is not None and return_dict and not isinstance(encoder_outputs, BaseModelOutput):
	encoder_outputs = BaseModelOutput(last_hidden_state=encoder_outputs[0],
						              hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
						              attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,)

第二部分,正式进入Decoder解码阶段,调用 self.decoder,self.decoder的参数存储在 named_modules()[‘decoder’] 中。self.decoder是T5Stack的一个实例,这个实例的 forward 函数由字嵌入、数值变化、逐层解码、LN&dropout、存储Dict、获取生成字 共6个部分组成,和self.encoder很相似,但在部分细节处有所不同。下面展开来看:

Part 1. 字嵌入

这里把decoder_input_ids传入 self.embed_tokens 函数,调用Embedding(50000, 512)。Embedding(50000, 512) 的参数存储在 named_modules()[‘embed_tokens’]中。它是torch.nn.functional.embedding的一个实例,以lookup table方式快速取得decoder_input_ids对应的inputs_embeds。初始状态下,inputs_embeds维度是(16, 1, 512),也就是(bc, decoder_seq_len, emb_dim)。

inputs_embeds = self.embed_tokens(input_ids)

此外与encoder不同的是,config.use_cache设置为True,这个下文会用到

Part 2. 数值变化

这部分包括key_value、attention mask、head mask、hidden states的变化。

  • 其中key_value、head mask、hidden states这三个部分和encoder完全相同(详见Step 2. Encoder推理 - Part 2. 数值变化),此处不再重复介绍
  • attention mask的计算
    • 与encoder相同之处:这里把decoder_attention_mask传入get_extend_attention_mask函数,这个函数把decoder_attention_mask中值为1的元素变为0,把值为0的元素变为-10000.0,用于decoder中的pad元素的mask;
    • 与encoder不同之处:
      • 这里额外计算causal_mask,并为decoder_attention_mask乘上causal_mask,用于decoder中的future元素的mask。
        if self.config.is_decoder:
            batch_size, seq_length = input_shape
            seq_ids = torch.arange(seq_length, device=device)
            causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
            extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
        
      • 还调用self.invert_attention_mask函数,计算encoder_extended_attention_mask,把encoder_attention_mask中值为1的元素变为0,把值为0的元素变为-1e9。
        encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
        encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e9
        

Part 3. 逐层解码

这部分进入Decoder的各个层,用for循环逐个访问self.decoder的8个层,我们把每层称为一个block,这8个block的内部完全一样(都没有relative_position_bias)。

for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)):
	layer_outputs = layer_module(hidden_states,
								 attention_mask=extended_attention_mask,
								 encoder_hidden_states=encoder_hidden_states,
								 encoder_attention_mask=encoder_extended_attention_mask,
								 use_cache=True,
								 output_attentions=False,
								 ...)

以decoder的block_0为例,进入到block内部,由Self-Attention Layer、Cross_attention Layer和Feed Forward layer 共三个模块组成。

# Apply Self-Attention layer
hidden_states = self.layer[0](hidden_states, ...)[0]

# Apply Cross-Attention layer
cross_attention_outputs = self.layer[1](hidden_states, ...)[0]

# Apply Feed Forward layer
hidden_states = self.layer[-1](hidden_states)

下面展开介绍Self-Attention Layer、Cross_attention Layer和Feed Forward layer。

Part 3.1 Self-Attention Layer

这部分和encoder完全一致(详见Step 2. Encoder推理 - Part 3.1 Self-Attention Layer),不再重复介绍。

Part 3.2 Cross-Attention Layer

cross-attention与self-attention不同在于,qkv不是相同的序列,q是decoder当前时刻的self-Attention Layer的输出,k和v是encoder的输出。

这里是调用T5LayerCrossAttention的forward函数,输入是decoder_input_emb, encoder_attention_mask, 以及key_value_states,这里的key_value_states就是encoder的输出,计算过程分为3部分:

  • 第一步,self.layernorm进行层归一化
    normed_hidden_states = self.layer_norm(hidden_states)
    
  • 第二步,self.EncDecAttention,它是T5Attention的一个实例,与encoder的SelfAttention计算几乎完全相同,不再重复罗列计算过程,不同之处在于:
    • query_states对应于decoder的输入,维度是(bc, n_head, decoder_input_len, head_dim),例如(16, 6, 1, 64);key_states和value_states对应于encoder的输出,维度是(bc, n_head, encoder_input_len, head_dim),例如(16, 6, 104, 64)
    • 由于decoder没有relative_position_bias,这里的positon_bias元素值全为零
    • 由于decoder的use_cache为True,这里会存储present_key_value_states,也就是(key_states和value_states)
  • 第三步,残差连接。
    layer_output = hidden_states + self.dropout(attention_output[0])
    
Part 3.3 Feed Forward Layer

这部分和encoder完全一致(详见Step 2. Encoder推理 - Part 3.2 Feed Forward Layer),不再重复介绍。

Part 4. LN&dropout

这部分和encoder完全一致(详见Step 2. Encoder推理 - Part 4. LN&dropout),不再重复介绍。

Part 5. 存储Dict

这部分和encoder完全一致(详见Step 2. Encoder推理 - Part 5. 存储Dict),不再重复介绍。

Part 6. 获取生成字

第一步,把decoder的输出送入self.lm_head,它是nn.modules.Linear的一个实例,用于得到生成字的logits。logits的维度是(16, 1, 50000),也就是(bc, 1, word_dict_size)。

lm_logits = self.lm_head(sequence_output)

第二步,用[:, -1, :]取当前生成字的logits,然后调用logits_processor预处理这个logits,logits_processor就是前文提到的MinLengthLogitsProcessor,控制生成序列长度不得小于min_len。输出维度为(16, 50000),也就是(bc, word_dict_size)

next_token_logits = outputs.logits[:, -1, :]
next_tokens_scores = logits_processor(input_ids, next_token_logits)

第三步,torch.argmax,得到当前概率值最大的token的id,维度为(16,),也就是(bc),

next_tokens = torch.argmax(next_tokens_scores, dim=-1)

第四步,把这个token_id拼接到input_ids上,用于下一个时刻的while循环。如果这个token_id是’<\s>'或者while循环结束,就会停止继续生成下一个字。

至此,greedy search就结束了,这里的input_ids就是最终的输出。

  • 16
    点赞
  • 35
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

SunnyGJing

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值