t5模型为什么可以通过传入past_key和past_value值来进行优化模型

t5模型是常用于文本生成部分的一个模型,也是目前我看到的各个nlp模型之中,唯一完整地使用transformer的所有完整结构(encoder部分加上decoder部分)的一个模型,接下来聊一下t5模型的生成优化过程。

优化的部分

首先对于生成这一块,最慢的速度在于推断而不在于训练,所以t5模型的优化部分在推断内容部分进行优化,推断部分使用的是transformer中的decoder结构,这里我们先看一下t5的decoder主要构成,我将它的结构图简化如下:

                    DecoderLayerTransformers   DecoderLayerAttention
decoder部分的结构图---
                    DecoderCrossTransformers   DecoderCrossAttention

1.decoderlayerattention的拼接

在这里的DecoderLayerAttention中优化的过程,采用的是计算完key和value的值之后,将之前同一网络层的key和value与现在网络层的key和value值拼接在一块,这里的关键点在于每一次对于下一个单词进行预测的时候,实际上只需要预测当前单词的概率即可,并不需要把所有的词语的概率全部都预测出来。
具体分析:
这里输入的query值只是由当前的单词id所构成,而key和value通过拼接之后,实际上跟原始的key的value的值相同

if past_key_value != None:
    key = torch.cat([past_key_value[0],key],dim=2)
if past_key_value != None:
    value = torch.cat([past_key_value[1],value],dim=2)

首先t5模型没有position_embedding,只有word_embedding的情况下,a的embedding和b的embedding拼接在一起,跟a+b的embedding拼接在一起的结果是一样的,这就保证了第一次decoderlayerattention网络层中的输入一样。
其次关键的在于,在attention的公式之中
A t t e n t i o n ( K , Q , V ) = S o f t m a x ( Q K T d k ) V Attention(K,Q,V) = Softmax(\frac{QK^{T}}{\sqrt{d_{k}}})V Attention(K,Q,V)=Softmax(dk QKT)V
对于每一个句子中单独的词来说,这个句子中每一个词的信息都会影响到当前词语的信息,因此在softmax的公式之中,这里的前后信息交互的关键点在于两点
1. Q K T QK^{T} QKT
这里的Q代表每一个词,而K向量代表与其他信息交互的句子中所有的词语,因为我们只需要关注当前的词,所以这里Q只需要取出当前的词即可,因为在我们预测下一个词的概率时,我们只需要知道当前词语的信息。但是这里的K代表的整个句子各个词语的词向量的信息,这些词语会对Q的当前这个词产生影响,因此这里的K必须是完整的,否则会造成只有部分的词向量对当前的K产生影响,造成结果的不准确。
2. S o f t m a x ( Q K T d k ) Softmax(\frac{QK^{T}}{\sqrt{d_{k}}}) Softmax(dk QKT)的结果乘上V
这里的V与上面的K向量同理,是用来交互的,因此V也必须保持向量的完整性,也就需要进行拼接。

2.decoderlayerattention不拼接完整的计算过程

为了读懂不拼接的计算过程,这里我专门比较了一下拼接与不拼接的整个计算过程,下面从整个模型具体的走一遍。
这里我们假设batch_size为1,从第二次生成的输入开始,即max_length = 2,而优化的部分max_length的值永远为1。

步骤1.原始输入

原始输入为input_ids = (1,2,768)(从第二次输入开始),则优化后的输入为input_ids = (1,1,768),这里的(1,1,768)为(1,2,768)的后面维度的矩阵信息。

步骤2.经过query,key,value网络层的输出结果

输入的input_ids分别经过query,key,value三个网络层,得到优化之前的输入为query = (1,2,768),key = (1,2,768),value = (1,2,768),如果是优化后的矩阵向量,则query = (1,1,768),key = (1,1,768),value = (1,1,768)

步骤3.拼接past_key和past_value的值

之前聊到key和value记录了整个句子完整的向量信息,所以key和value的值需要保持完整,因此这里优化的部分需要拼接上之前计算的key和value的信息,得到与未优化之前相同的key和value的值
key = (1,2,768),value = (1,2,768)

步骤4.将query,key,value的值进行拆解
key = key.view(batch_size,-1,self.config.num_heads,self.config.size_per_head)
value = value.view(batch_size,-1,self.config.num_heads,self.config.size_per_head)        

这里只是拆解最后一个维度,跟前面的内容没有关系,
未优化之后拆解出的向量维度为
query = (1,2,12,64),key = (1,2,12,64),value = (1,2,12,64)
优化之后拆解出的向量维度为
query = (1,2,12,64)(这里的1为未优化2的最后一维),key = (1,2,12,64),value = (1,2,12,64)

步骤5.transpose操作
query = query.transpose(1,2)
key = key.transpose(1,2)
value = value.transpose(1,2)

未优化的向量
query = (1,12,2,64),key = (1,12,2,64),value = (1,12,2,64)
优化后的向量
query = (1,12,1,64),key = (1,12,2,64),value = (1,12,2,64)
这里的未优化的query与优化的query的关系是未优化的query中每间隔一波会得到去取优化的query,具体内容如下

        [[[[1.4670e-02,  1.2135e-01, -1.9623e-01, -6.7619e-02, -2.3180e-01,
           -7.0569e-02,  3.0198e-02, -1.6405e-01, -2.1121e-01,  3.5561e-02,
            5.2837e-02, -4.2778e-02,  4.4988e-02, -2.1911e-01,  9.9120e-03,
            1.7280e-01,  9.6410e-02, -7.0861e-02,  1.2775e-01, -1.2981e-01,
           -2.9778e-02, -1.5738e-01, -6.2758e-02,  1.1868e-02, -6.1655e-04,
           -4.7120e-04,  2.2410e-01,  7.7863e-02, -1.5514e-01, -6.7375e-02,
           -2.8295e-02, -3.7375e-02,  6.5420e-02, -7.0116e-02, -9.4598e-02,
            7.6665e-03,  9.0203e-03,  9.5351e-03,  6.0339e-03,  4.3080e-02,
           -3.7418e-03,  7.9097e-03, -2.0718e-02,  6.0211e-02, -2.0378e-02,
            2.0389e-02,  1.0511e-01, -1.4655e-01, -1.4036e-01, -9.6260e-02,
           -1.9263e-03, -9.3218e-02, -1.2116e-02, -2.5003e-01,  1.2911e-01,
            2.0643e-01, -3.3275e-02, -3.7536e-02, -2.0530e-01,  6.0305e-02,
           -5.0186e-02,  9.7535e-02,  1.2266e-01, -5.5510e-02],
           
          (下面为优化的参数部分1)
          [-4.9299e-04, -2.6518e-02, -1.0284e-01, -2.8316e-02,  1.4845e-01,
            2.4035e-02, -1.1961e-01,  1.4806e-02, -1.0239e-02, -1.2139e-01,
           -9.8680e-02,  8.0918e-02,  1.0025e-01,  2.6667e-03, -9.8165e-03,
            4.2049e-02, -7.7295e-03, -1.2139e-01, -9.3672e-02, -9.5083e-02,
            4.5946e-02, -3.0294e-02,  1.1112e-01,  1.0422e-03,  1.6938e-01,
            5.7986e-02, -3.0207e-02,  1.9465e-01,  1.0476e-01,  1.9391e-01,
           -1.1847e-02,  8.8533e-02,  6.9837e-02,  8.0209e-02, -2.8369e-03,
           -1.0628e-01,  3.3870e-02,  2.6993e-02,  9.9969e-02, -5.6613e-02,
           -1.8507e-01, -5.4061e-02,  1.0212e-01,  7.1761e-02, -7.0027e-02,
           -8.0300e-02,  1.0698e-01, -2.2551e-02, -5.0615e-02,  4.5206e-02,
            1.2523e-01,  2.4763e-02, -3.0133e-02,  1.3311e-01, -6.7099e-02,
            2.0527e-02, -7.5543e-02,  2.0281e-02,  3.3036e-02, -4.6203e-02,
           -3.1867e-02,  2.7087e-02,  4.4259e-02, -4.5838e-02]],

         [[-1.4408e-01,  1.5186e-01, -3.5974e-03, -6.0452e-02, -7.9654e-02,
           -4.7384e-02, -2.2503e-02,  3.4365e-01,  1.2694e-01, -4.5820e-02,
            3.0887e-02,  1.0342e-01, -1.0327e-01, -4.4748e-03, -1.8945e-01,
           -2.8120e-02,  1.4307e-01,  2.0421e-02,  1.0495e-01,  3.9390e-02,
           -1.6610e-01, -8.9004e-02, -8.5324e-02,  8.1240e-02, -8.8805e-02,
            5.4473e-02,  2.4430e-01, -1.9869e-01,  1.1048e-01, -5.5874e-02,
            1.5152e-01,  7.5828e-02, -2.1933e-01, -2.9484e-01,  3.2189e-03,
            8.5885e-02, -5.4767e-02, -1.8218e-01,  1.7896e-01, -6.2724e-02,
            4.0281e-02, -1.1383e-01, -1.2164e-01, -2.7832e-01,  1.3230e-01,
           -2.9016e-02, -1.6377e-01,  1.7774e-01,  1.0014e-01,  1.1170e-01,
            8.6232e-03,  2.3320e-01,  5.7124e-03, -4.9258e-02,  1.1669e-02,
           -8.8721e-02, -2.6996e-02, -2.5208e-02,  6.9340e-02, -4.9958e-03,
           -8.7542e-02,  1.2076e-01, -1.4579e-02, -7.5249e-02],
           
          (下面为优化的参数部分2)
          [-7.5944e-02,  3.1948e-02, -1.3581e-01,  1.5007e-01, -2.0339e-02,
           -8.8263e-02, -2.0876e-02,  9.4324e-02,  8.8024e-02, -7.6570e-02,
            2.4468e-02, -1.4054e-01,  1.4406e-01, -6.7122e-02, -3.1915e-01,
            1.3064e-01, -1.0095e-02,  7.6921e-02,  1.3721e-01,  1.6839e-01,
            1.2220e-01,  1.2142e-01, -1.0998e-01,  1.4507e-01,  1.7634e-04,
           -1.6147e-01,  4.2507e-02,  1.6338e-01,  6.4832e-02, -3.5208e-02,
           -9.1079e-02,  5.4273e-04, -9.4308e-03,  5.4123e-02,  4.8480e-02,
            1.0629e-01, -1.5671e-02, -3.0359e-02,  3.4089e-02,  1.0919e-02,
           -7.4085e-02, -8.4118e-02, -3.2656e-02, -5.7829e-02,  7.5287e-02,
            1.6278e-01,  6.6263e-02, -9.5057e-02, -6.9226e-02, -1.0631e-01,
           -3.3601e-02, -1.8329e-02,  1.4080e-01, -2.6989e-02,  2.4369e-01,
           -1.7388e-01, -1.0928e-01, -1.9072e-01,  1.3444e-02,  1.0209e-01,
           -7.9701e-02, -1.8037e-02,  9.1845e-02, -9.0854e-02]],

可以看出这里的参数是隔着相等的

步骤6.相乘得到scores的内容
scores = torch.matmul(
    query,key.transpose(3,2)
)

得到scores优化后的内容和未曾优化后的内容
未曾优化后得到的scores内容

scores = 
tensor([[[[-0.5559,  0.4717],
         (下面是优化后的参数内容)
          [ 1.9015,  2.6016]],
         [[ 2.4006, -5.8395],
         (下面是优化后的参数内容)
          [ 2.7282,  2.8212]],
         [[-4.8745, -3.9924],
        (下面是优化后的参数内容)
          [-0.6129,  0.8469]],
         [[-1.1932, -6.7089],
       (下面是优化后的参数内容)
          [ 3.9086, -7.9154]],
         [[-5.8440,  3.6701],
       (下面是优化后的参数内容)
          [ 8.3781,  7.1280]],
         [[-5.2619, -3.4999],
       (下面是优化后的参数内容)
          [-0.4299, -1.0349]],
         [[-2.2012,  0.9375],
       (下面是优化后的参数内容)
          [ 8.4696,  4.2147]],
         [[-3.0179, -1.1522],
      (下面是优化后的参数内容)
          [ 2.1133,  1.3674]],
         [[ 0.6151,  1.4965],
      (下面是优化后的参数内容)
          [ 3.1316,  5.0822]],
         [[-3.9366, -1.4958],
      (下面是优化后的参数内容)
          [ 1.8155,  0.0675]],
         [[ 3.9352,  0.9258],
      (下面是优化后的参数内容)
          [ 1.8104,  5.5592]],
         [[-5.1672, -2.3799],
     (下面是优化后的参数内容)
          [-1.6228, -1.7248]]],
步骤7.scores+position_bias值:没变化

position_bias相等,所以没变化
这里的值未优化时scores.shape = (2,12,2,2),优化之后scores.shape = (2,12,1,2)间隔相等。

步骤8.计算attn_weights与value相乘
attn_output = torch.matmul(attn_weights,value)

这里的value优化与未优化的参数值是一样的,所以经过相乘之后,得到attn_output的值优化与未优化的还是间隔相等的
这里未优化的情况下,attn_weights = (2,12,2,2),value = (2,12,2,64),相乘之后
attn_output = ( 2 , 12 , 2 , 2 ) ∗ ( 2 , 12 , 2 , 64 ) = ( 2 , 12 , 2 , 64 ) (2,12,2,2)*(2,12,2,64) = (2,12,2,64) (2,12,2,2)(2,12,2,64)=(2,12,2,64)
优化的情况下,attn_weights = (2,12,1,2),value = (2,12,2,64),相乘之后attn_output = ( 2 , 12 , 1 , 2 ) ∗ ( 2 , 12 , 2 , 64 ) = ( 2 , 12 , 1 , 64 ) (2,12,1,2)*(2,12,2,64) = (2,12,1,64) (2,12,1,2)(2,12,2,64)=(2,12,1,64)
此时还是间隔相等

步骤9.计算attn_output
attn_output = attn_output.transpose(1,2)

这里transpose(1,2)之后,如果优化的情况下attn_output = (2,12,2,64)->(2,2,12,64),不优化的情况下attn_output = (2,12,1,64)->(2,1,12,64),此时由于transpose翻转矩阵的存在,本身矩阵由间隔相等变为了最后一维度相等
开头一次transpose将矩阵的形状翻转 Q T Q^T QT,结尾的时候又调用了一次矩阵的翻转
这里的结果翻转最主要的是与开头的翻转进行抵消,开头有这样一段翻转

query = query.transpose(1,2)
key = key.transpose(1,2)
value = value.transpose(1,2)

而结尾的时候将相乘出来的结果翻转一次,能够将开头的翻转抵消掉

attn_output = attn_output.transpose(1,2)

3.decodercrossattention的拼接

这里的拼接过程较为简单,只需要拼接上之前在encoderlayerattention网络层部分的输出即可,所以直接保存encoderlayerattention网络层之前的输出经过当前网络层的内容,避免重复计算。
注意这里保存的也是每一层的encoder编码部分的输出内容
这里的key和value都是当前网络层经过key_layer和value_layer线性层的输出

attn_output = attn_output.transpose(1,2).contiguous().view(batch_size,-1,self.config.num_heads*self.config.size_per_head)

4.简化运算,从另外一个角度来看只取出最后一个维度的计算结果

A t t e n t i o n ( K , Q , V ) = S o f t m a x ( Q K T d k ) V Attention(K,Q,V) = Softmax(\frac{QK^{T}}{\sqrt{d_{k}}})V Attention(K,Q,V)=Softmax(dk QKT)V
这里我们去除掉对于维度变化等结果没有影响的操作过程,将公式简化为如下操作
Q K T V QK^{T}V QKTV
这样原始变化为 Q K T QK^{T} QKT = ( 1 , 12 , 5 , 64 ) ∗ ( 1 , 12 , 64 , 5 ) = ( 1 , 12 , 5 , 5 ) (1,12,5,64)*(1,12,64,5) = (1,12,5,5) (1,12,5,64)(1,12,64,5)=(1,12,5,5)
优化之后的变换为 Q K T QK^{T} QKT = ( 1 , 12 , 1 , 64 ) ∗ ( 1 , 12 , 64 , 5 ) = ( 1 , 12 , 1 , 5 ) (1,12,1,64)*(1,12,64,5) = (1,12,1,5) (1,12,1,64)(1,12,64,5)=(1,12,1,5)
这样优化之后的变换 ( 1 , 12 , 1 , 5 ) (1,12,1,5) (1,12,1,5)正好为 ( 1 , 12 , 5 , 5 ) (1,12,5,5) (1,12,5,5)的最后一波,所以在每一个(1,12)内,这里的优化后的5向量是每隔5个位置出现一波
接着操作 ( Q K T ) V (QK^{T})V (QKT)V,未优化的情况下等于 ( 1 , 12 , 5 , 5 ) ∗ ( 1 , 12 , 5 , 64 ) = ( 1 , 12 , 5 , 64 ) (1,12,5,5)*(1,12,5,64) = (1,12,5,64) (1,12,5,5)(1,12,5,64)=(1,12,5,64)
优化下的情况等于 ( 1 , 12 , 1 , 5 ) ∗ ( 1 , 12 , 5 , 64 ) = ( 1 , 12 , 1 , 64 ) (1,12,1,5)*(1,12,5,64) = (1,12,1,64) (1,12,1,5)(1,12,5,64)=(1,12,1,64)
这里正好(1,64)就是(5,64)的最后一维度,所以每隔5波出现一次,总共出现12次
最后这里翻转就是将这些间隔的相同内容聚集在一起
( 1 , 12 , 5 , 64 ) − > ( 1 , 5 , 12 , 64 ) − > ( 1 , 5 , 768 ) (1,12,5,64) -> (1,5,12,64)->(1,5,768) (1,12,5,64)>(1,5,12,64)>(1,5,768), ( 1 , 12 , 1 , 64 ) − > ( 1 , 1 , 12 , 64 ) − > ( 1 , 5 , 768 ) (1,12,1,64) -> (1,1,12,64)->(1,5,768) (1,12,1,64)>(1,1,12,64)>(1,5,768),因此这里最后一个维度的内容相同

5.随便说说

今天在力扣中无意间看到桶排序,感觉可以用于生成的topk计算算法优化
桶排序算法
不过其实整体数据也没多少,所以感觉topk优化的话提升的效率也不大

  • 1
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值