transformers5--t5模型中encoder与decoder内容不同解读

关于transformers整体结构的解答可以查看相应的解析:解析网站
本质上t5使用的是编码和解码的操作,transformers的网络结构如下:
transformers编解码过程图片首先需要理解这个transformer对应的结构图,比如我们要想通过输入我爱中国得到输出I love China,那么Inputs输入永远是我爱中国,而Outputs刚开始为+position encoding,接下来产生预测I了之后,继续将Outputs(shifted right)变为(起始符)+I+position encoding,然后产生预测love之后,继续将Outputs(shifted right)变为(起始符)+I+love+positional encoding,以此类推。
由此可见,Inputs的部分始终不变,Outputs(shifted right)部分在不断地变化,从而引起预测结果不断地改变。
此外这种encoder-decoder结构还引出了一种attention的变化,也就是说在t5模型之中,encoder部分的attention与decoder中第二个部分的attention结构一致,decoder attention中第一个部分的attention加入了mask掩码的内容,这与bert4keras中的代码保持一致。

查看transformers库之中的encoder和decoder部分内容的不同

仔细观察发现,t5selfattention和t5crossattention的区别在于t5crossattention之中多加入了两个参数
t5selfattention的内容

self_attention_outputs = self.layer[0](
    hidden_states,
    attention_mask=attention_mask,
    position_bias=position_bias,
    layer_head_mask=layer_head_mask,
    past_key_value=self_attn_past_key_value,
    use_cache=use_cache,
    output_attentions=output_attentions,
)

t5crossattention的内容

cross_attention_outputs = self.layer[1](
    hidden_states,
    key_value_states=encoder_hidden_states,
    attention_mask=encoder_attention_mask,
    position_bias=encoder_decoder_position_bias,
    layer_head_mask=cross_attn_layer_head_mask,
    past_key_value=cross_attn_past_key_value,
    query_length=query_length,
    use_cache=use_cache,
    output_attentions=output_attentions,
)

可以看出来,上文的cross_attention_outputs之中多出了两个参数:key_value_states和query_length,所以这里重点看key_value_states和query_length对cross_attention_outputs造成的影响(即key_value_states和query_length对T5Attention造成的影响)
所以接下来,我们需要进入到代码之中,去查看key_value_states以及query_length对T5Attention造成的影响

综合分析t5LayerSelfAttention和t5LayerCrossAttention的运行的不同

1.先进行encoder的部分
t5LayerSelfAttention的输入:6个T5LayerSelfAttention
输入的内容(1,11,512)
2.再进行decoder的部分
t5LayerSelfAttention的输入:6个T5LayerSelfAttention
输入的内容(1,1,512)
t5LayerCrossAttention的输入:6个T5CrossAttention
输入的内容(1,1,512)
这里面的维度变化在T5ForConditionalGeneration中有改变过

if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
   # get decoder inputs from shifting lm labels to the right
   decoder_input_ids = self._shift_right(labels)

本身decoder_input_ids = (1,11,512),经过_shift_right之后变成了(1,1,512)
T5ForConditionalGeneration->T5PreTrainedModel->PreTrainedModel->GenerationMixin->generate函数
最终找出来是在transformers中的generation_utils.py之中找出来的

if "decoder_input_ids" in model_kwargs:
    input_ids = model_kwargs.pop("decoder_input_ids")
    print('111input_ids = 111')
    print(input_ids)
    print('111111111111111111')
else:
    input_ids = self._prepare_decoder_input_ids_for_generation(
        input_ids, decoder_start_token_id=decoder_start_token_id, bos_token_id=bos_token_id
    )
    print('222input_ids = 222')
    print(input_ids)
    print('222222222222222222')

这里输入的各种参数

input_ids = 
tensor([[13959,  1566,    12,  2968,    10,    37,   629,    19,  1627,     5,
             1,     0,     0,     0,     0]])
decoder_start_token_id = 
None
bos_token_id = 
None

出来之后的

input_ids = 222
tensor([[0]])

然后经历6轮的LayerSelfAttention和LayerCrossAttention网络层部分,最后出来了(1,1,512)的tensor内容
decoder出来的内容部分如下:

decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            past_key_values=past_key_values,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=attention_mask,
            head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

sequence_output = decoder_outputs[0]

形成的sequence_output = (1,1,512)
输出的各项参数为

decoder_outputs.past_key_values = 

((tensor([[[[-4.8649e-01, -2.3323e+00, -1.1428e+00,  4.3997e-01, -3.7448e+00,
            5.9211e-01,  1.7371e+00,  1.2648e-01, -9.5232e-01,  6.4317e-01,
            6.5032e-02, -2.7661e+00, -2.9257e-01, -2.2728e+00,  1.4708e+00,
            3.6940e+00, -5.9305e-01, -2.2253e+00, -2.2925e+00,  1.2926e+00,
           -1.6622e+00,  1.5806e-01, -9.8186e-01,  6.9422e-01, -2.3424e+00,
           -1.5638e-01, -9.2692e-01,  2.5009e+00,  1.5147e+00, -3.6560e-01,
            3.0006e-01,  9.5156e-01,  2.0886e+00,  3.6983e-01, -1.0588e+00,
            2.6796e+00, -1.4096e+00, -1.1152e+00, -2.3030e+00, -1.3433e+00,
            1.9916e+00,  2.5363e-03, -2.4754e+00,  7.7748e-01, -1.2229e+00,
           -1.9101e+00,  1.9616e+00, -1.2805e+00,  1.0394e+00,  1.6140e-01,
           -2.3916e-01,  2.9783e-01,  1.6426e+00, -1.3518e+00, -1.1187e+00,
           -1.4495e+00, -2.1039e+00,  2.9519e+00, -1.8293e+00,  1.2496e+00,
            6.0215e-01, -2.5693e+00, -1.7539e+00, -5.6927e-01]],
            ............
		  [[ 3.8271e-01,  3.9878e-01,  3.0701e-01,  ...,  2.0659e+00,
            1.1919e+00,  9.1220e-01],
          [ 2.2469e-01, -1.3852e+00, -2.3070e-01,  ..., -4.1294e+00,
           -4.6317e+00, -6.0171e-01],
          [ 4.2913e-02, -2.8669e-01,  1.4512e-01,  ...,  2.4677e-01,
            3.0281e-02,  6.6158e-01],
          ...,
          [ 8.8233e+00,  1.4664e+00, -6.6772e+00,  ...,  5.7047e+00,
            3.9132e+00,  4.7790e+00],
          [ 5.3985e+00, -9.5581e-01, -2.2232e+00,  ...,  7.3522e+00,
            1.5856e+00, -7.5307e+00],
          [-1.8160e+00, -2.0803e+00, -9.2405e-01,  ...,  1.6660e+00,
            1.1615e+00, -1.7454e-01]]]])))
decoder_outputs.hidden_states = None
decoder_outputs.attentions = None
decoder_outputs.cross_attentions = None
encoder_outputs.last_hidden_state = 
tensor([[[ 0.0154,  0.1263,  0.0301,  ..., -0.0117,  0.0373,  0.1015],
         [-0.1926, -0.1285,  0.0228,  ..., -0.0339,  0.0535,  0.1575],
         [ 0.0109, -0.0210,  0.0022,  ...,  0.0008, -0.0056, -0.0393],
         ...,
         [-0.1581, -0.0719,  0.0208,  ..., -0.1778,  0.1037, -0.1703],
         [ 0.0142, -0.1430,  0.0148,  ...,  0.0224, -0.1906, -0.0547],
         [ 0.0756, -0.0119, -0.0273,  ..., -0.0044, -0.0505,  0.0554]]])
encoder_hidden_states = 
None
encoder_attentions = 
None

这里输出的encoder_output.last_hidden_state = (1,11,512)
decoder_output.past_key_values每一个的形状为(1,8,1,64)不知道是用来干什么的
仔细观察decoder_output.past_key_values,发现需要转到t5stack类别的最后面进行查看

print('t5stack past_key_values = ')
print(past_key_values)
print('--------------------------')
return BaseModelOutputWithPastAndCrossAttentions(
    last_hidden_state=hidden_states,
    past_key_values=present_key_value_states,
    hidden_states=all_hidden_states,
    attentions=all_attentions,
    cross_attentions=all_cross_attentions,
)

在进入BaseModelOutputWithPastAndCrossAttentions之前,获得的past_key_values = [None, None, None, None, None, None]
(6个网络层)
也就是说past_key_values为进入BaseModelOutputWithPastAndCrossAttention之后才变换的,即是在modeling_t5.py之中的

return BaseModelOutputWithPastAndCrossAttentions(
    last_hidden_state=hidden_states,
    past_key_values=present_key_value_states,
    hidden_states=all_hidden_states,
    attentions=all_attentions,
    cross_attentions=all_cross_attentions,
)

返回之前,

past_key_values = None

经过阅读代码发现,这里的decoder每次都输入的为一个一维的数值

decoder_input_ids = 
tensor([[0]])
decoder_input_ids = 
tensor([[644]])
decoder_input_ids = 
tensor([[4598]])
decoder_input_ids = 
tensor([[229]])
decoder_input_ids = 
tensor([[19250]])
decoder_input_ids = 
tensor([[5]])

也就是说,这里经历了

decoder_input_ids = self._shift_right(labels)

右移之后

bert4keras t5decoder解读

***inputs = ***
[<tf.Tensor 'Input-Context:0' shape=(?, ?, 768) dtype=float32>, <tf.Tensor 'Decoder-Input-Token:0' shape=(?, ?) dtype=float32>]

t5 decoder的输入为encoder输出和原始的token_ids???

BaseModelOutputWithPastAndCrossAttention解读

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值