Paper translation: MiniViT: Compressing Vision Transformers with Weight Multiplexing (CVPR 2022)


Paper: MiniViT: Compressing Vision Transformers with Weight Multiplexing
Code (open source): https://github.com/microsoft/Cream

1. Abstract

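Vision Transformers have attracted wide attention in computer vision because of their high model performance. However, they carry a large number of parameters, which limits their applicability on memory-constrained devices. To alleviate this problem, this paper proposes MiniViT, a new compression framework that reduces the parameters of Vision Transformers while retaining the same performance.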

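The central idea of MiniViT is to multiplex the weights of consecutive Vision Transformer blocks. More specifically, the weights are shared across layers, and a transformation is imposed on them to increase diversity. Weight distillation is also applied to transfer knowledge from large-scale ViT models to the compact weight-multiplexed models.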

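Comprehensive experiments demonstrate the effectiveness of MiniViT. It reduces the size of the pre-trained Swin-B Transformer by 48% while improving Top-1 accuracy on ImageNet by 1.0%. Moreover, using a single layer of parameters, MiniViT compresses DeiT-B by 9.7×, from 86M to 9M parameters, without seriously compromising performance. Finally, the transferability of MiniViT is verified by its performance on downstream benchmarks.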

2. Introduction

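Large-scale pre-trained Vision Transformers such as ViT, CvT and Swin have recently drawn great attention due to their high capacity and superior performance on downstream tasks. However, they usually involve huge model sizes and large amounts of training data. For example, ViT needs 300 million images to train a huge model with 632 million parameters to reach state-of-the-art image classification performance. Likewise, Swin uses 200-300 million parameters and is pre-trained on ImageNet-22K to perform well on downstream detection and segmentation tasks.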

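Hundreds of millions of parameters consume considerable storage and memory. This makes these models unsuitable for applications with limited computational resources, such as edge and IoT devices, or for tasks that require real-time predictions. Recent studies have shown that large-scale pre-trained models are over-parameterized. It is therefore desirable to eliminate redundant parameters and computational overhead without compromising the performance of these pre-trained models.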

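Weight sharing is a simple yet effective technique for reducing model size. The idea of weight sharing in neural networks was first proposed by LeCun and Hinton in the 1990s, and it has recently been reinvented for compressing Transformer models in natural language processing (NLP). The most representative work is ALBERT, which introduces cross-layer weight sharing so that the number of parameters does not grow with network depth. This technique can significantly reduce model size without seriously hurting performance, thereby improving parameter efficiency. However, the effectiveness of weight sharing for compressing Vision Transformers has not been well explored.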
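The parameter savings from cross-layer sharing can be checked with simple arithmetic. The Python sketch below is a minimal illustration and is not code from the paper. The helper names are hypothetical. It assumes a standard pre-norm ViT block with a fused QKV projection, biases on every linear layer, an MLP expansion ratio of 4, and two LayerNorms per block. The width (D = 768) and depth (12 blocks) match DeiT-B. Patch embedding, positional embedding and the classifier head are not counted.

```python
def vit_block_params(dim: int, mlp_ratio: int = 4) -> int:
    """Parameters of one pre-norm ViT block (assumed layout: fused QKV, biases, 2 LayerNorms)."""
    hidden = dim * mlp_ratio
    qkv = dim * 3 * dim + 3 * dim   # fused Q/K/V projection
    proj = dim * dim + dim          # attention output projection
    fc1 = dim * hidden + hidden     # MLP expansion
    fc2 = hidden * dim + dim        # MLP contraction
    norms = 2 * 2 * dim             # two LayerNorms, each with gamma and beta
    return qkv + proj + fc1 + fc2 + norms


def encoder_params(dim: int, depth: int, shared: bool) -> int:
    """Encoder parameters; with cross-layer sharing one block is reused `depth` times."""
    per_block = vit_block_params(dim)
    return per_block if shared else per_block * depth


if __name__ == "__main__":
    independent = encoder_params(768, 12, shared=False)
    shared = encoder_params(768, 12, shared=True)
    print(independent, shared, independent // shared)  # 85054464 7087872 12
```

Under these assumptions, the 12 independent blocks hold about 85M parameters. A fully shared encoder keeps only one block's worth (about 7.1M), and that figure stays constant however deep the network is. This depth independence is the property ALBERT exploits.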

To verify this, the authors apply cross-layer weight sharing to DeiT-S and Swin-B Transformers. Unexpectedly, this direct use of weight sharing causes two serious problems:

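1. Training instability: the authors observe that weight sharing across Transformer layers makes training unstable, and can even cause training to collapse as the number of shared layers increases;
2. Performance degradation: weight-shared Vision Transformers perform noticeably worse than the original Vision Transformers. For example, although weight sharing reduces the number of model parameters by 4×, it causes a 5.6% accuracy drop for Swin-S.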

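To investigate the causes, the authors analyze two things during training: the ℓ2-norm of the gradients, and the similarity between the model's intermediate feature representations before and after weight sharing. They find that strictly identical weights across layers are the main cause of the problems. In particular, the layer normalizations in different Transformer blocks should not be identical under weight sharing, because features at different layers have different scales and statistics. Meanwhile, the ℓ2-norm of the gradients becomes large and fluctuates across layers, which makes training unstable.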
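A one-line chain-rule argument suggests why sharing can inflate gradient norms. This is an illustrative reading, not the paper's own derivation. Suppose one weight matrix $W$ is reused by all $L$ layers. Its gradient is then the sum of the gradients each layer would receive if it had its own copy $W^{(l)}$, with gradients treated as flattened vectors:

$$
\frac{\partial \mathcal{L}}{\partial W} = \sum_{l=1}^{L} \left.\frac{\partial \mathcal{L}}{\partial W^{(l)}}\right|_{W^{(1)}=\cdots=W^{(L)}=W},
\qquad
\left\lVert \frac{\partial \mathcal{L}}{\partial W} \right\rVert_2 \le \sum_{l=1}^{L} \left\lVert \frac{\partial \mathcal{L}}{\partial W^{(l)}} \right\rVert_2 \le L \max_{l} \left\lVert \frac{\partial \mathcal{L}}{\partial W^{(l)}} \right\rVert_2 .
$$

When the per-layer contributions point in similar directions, the shared gradient approaches the upper bound and grows with the number of shared layers. This is consistent with the observation that instability worsens as more layers are shared.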

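Finally, the Centered Kernel Alignment (CKA) value, a popular similarity metric, drops significantly in the last few layers. This indicates that the feature maps produced by the model before and after weight sharing are less correlated, which may explain the performance degradation.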
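The CKA measurement can be reproduced in a few lines. This sketch assumes linear CKA, the linear-kernel variant from Kornblith et al. (2019); the paragraph above does not say which kernel the authors used. `linear_cka` is a hypothetical helper name. The rows of the two matrices must correspond to the same input examples, but their feature widths may differ.

```python
import numpy as np


def linear_cka(x: np.ndarray, y: np.ndarray) -> float:
    """Linear CKA between feature matrices of shape (n_examples, n_features)."""
    x = x - x.mean(axis=0, keepdims=True)   # center each feature column
    y = y - y.mean(axis=0, keepdims=True)
    cross = np.linalg.norm(y.T @ x, ord="fro") ** 2
    norm_x = np.linalg.norm(x.T @ x, ord="fro")
    norm_y = np.linalg.norm(y.T @ y, ord="fro")
    return float(cross / (norm_x * norm_y))


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    feats = rng.normal(size=(256, 64))                 # e.g. one layer's features before sharing
    q, _ = np.linalg.qr(rng.normal(size=(64, 64)))      # random orthogonal matrix
    print(linear_cka(feats, feats))                    # ~1.0
    print(linear_cka(feats, 3.0 * feats @ q))          # ~1.0: invariant to rotation and isotropic scaling
    print(linear_cka(feats, rng.normal(size=(256, 64))))  # much lower for unrelated features
```

Applying this to each layer's features from the original model and the same layer of the weight-shared model gives one CKA value per layer. That is the kind of per-layer comparison the paragraph describes.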

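This paper proposes a new technique called Weight Multiplexing to address these problems. Weight Multiplexing consists of two components, Weight Transformation and Weight Distillation, which together compress pre-trained Vision Transformers.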
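The key idea of Weight Transformation is to transform the shared weights so that the weights of different layers differ slightly, as shown in Figure 2. This operation promotes parameter diversity and also improves training stability.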

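More specifically, simple linear transformations are applied to the multi-head self-attention (MSA) module and the multilayer perceptron (MLP) module of each weight-shared Transformer layer. Each layer has its own transformation matrices, so the attention weights and the MLP outputs differ across layers. In addition, the layer normalizations of different layers are kept separate instead of sharing the same parameters. As a result, optimizing the weight-sharing Transformer network becomes more stable.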
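Below is a minimal NumPy sketch of this idea, restricted to the MSA part. The concrete form of the per-layer transformation is an assumption, not taken from the paper's code. It uses small head-by-head mixing matrices applied to the attention logits before the softmax and to the attention maps after it, in the style of talking-heads attention. Each layer also gets its own LayerNorm. All names are hypothetical, and the MLP-side transformation is not shown.

```python
import numpy as np


def layer_norm(x, gamma, beta, eps=1e-6):
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return gamma * (x - mu) / np.sqrt(var + eps) + beta


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def multiplexed_msa(z, shared, layer, num_heads):
    """MSA sub-block with shared projections but per-layer LayerNorm and head-mixing matrices."""
    n, dim = z.shape
    d = dim // num_heads
    h = layer_norm(z, layer["gamma"], layer["beta"])            # per-layer LayerNorm, never shared
    q, k, v = (t.reshape(n, num_heads, d).transpose(1, 0, 2)    # each (heads, n, d)
               for t in np.split(h @ shared["w_qkv"], 3, axis=-1))
    logits = q @ k.transpose(0, 2, 1) / np.sqrt(d)               # (heads, n, n)
    logits = np.einsum("hm,mij->hij", layer["f_pre"], logits)   # mix heads before softmax
    attn = np.einsum("hm,mij->hij", layer["f_post"], softmax(logits))  # mix heads after softmax
    out = (attn @ v).transpose(1, 0, 2).reshape(n, dim)         # concatenate heads
    return z + out @ shared["w_o"]                               # shared output projection + residual


def make_layer_params(rng, dim, num_heads):
    """Per-layer parameters; near-identity init (an assumption) starts close to plain sharing."""
    def near_identity():
        return np.eye(num_heads) + 0.01 * rng.normal(size=(num_heads, num_heads))
    return {"gamma": np.ones(dim), "beta": np.zeros(dim),
            "f_pre": near_identity(), "f_post": near_identity()}


rng = np.random.default_rng(0)
dim, heads, depth = 64, 4, 6
shared = {"w_qkv": rng.normal(scale=dim ** -0.5, size=(dim, 3 * dim)),
          "w_o": rng.normal(scale=dim ** -0.5, size=(dim, dim))}
layers = [make_layer_params(rng, dim, heads) for _ in range(depth)]

z = rng.normal(size=(10, dim))
for layer in layers:                       # the same shared weights are reused at every depth
    z = multiplexed_msa(z, shared, layer, heads)

n_shared = sum(w.size for w in shared.values())
n_per_layer = sum(p.size for p in layers[0].values())
print(z.shape, n_shared, n_per_layer)      # (10, 64) 16384 160
```

The per-layer overhead (160 values here) is tiny next to the shared projection weights (16,384). That lets the transformed layers differ from one another at almost no parameter cost.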

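To prevent performance degradation, the authors further combine Weight Multiplexing with Weight Distillation. This transfers the information embedded in the pre-trained model to the small weight-shared model, yielding more compact and lighter models. Previous work relies only on prediction-level distillation. This method additionally uses attention-level and hidden-state distillation, which lets the smaller model better mimic the behavior of the original large pre-trained teacher model.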
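The paragraph names three distillation levels but gives no formulas, so the NumPy sketch below assumes generic forms for each:

- a temperature-scaled KL divergence on class predictions, including the common T² factor;
- a KL divergence between teacher and student attention distributions;
- a mean-squared error on hidden states.

The sketch also assumes that teacher and student have the same number of heads, tokens and hidden width; otherwise a projection would be needed. The equal default weights are an assumption too, and all names are hypothetical.

```python
import numpy as np


def log_softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    return x - np.log(np.exp(x).sum(axis=axis, keepdims=True))


def kl_div(teacher_logits, student_logits, axis=-1):
    """Mean KL(teacher || student) between the distributions defined by two sets of logits."""
    log_p = log_softmax(teacher_logits, axis)
    log_q = log_softmax(student_logits, axis)
    return float((np.exp(log_p) * (log_p - log_q)).sum(axis=axis).mean())


def multiplexing_distill_loss(teacher, student, temperature=1.0,
                              w_pred=1.0, w_attn=1.0, w_hidden=1.0):
    """Weighted sum of prediction-, attention- and hidden-state-level distillation terms."""
    # prediction level: match softened class distributions (T^2 keeps the gradient scale comparable)
    pred = temperature ** 2 * kl_div(teacher["logits"] / temperature,
                                     student["logits"] / temperature)
    # attention level: match each head's attention distribution over keys (pre-softmax scores given)
    attn = kl_div(teacher["attn_scores"], student["attn_scores"])
    # hidden-state level: match intermediate token features
    hidden = float(np.mean((teacher["hidden"] - student["hidden"]) ** 2))
    return w_pred * pred + w_attn * attn + w_hidden * hidden


rng = np.random.default_rng(0)
teacher = {"logits": rng.normal(size=(8, 1000)),             # (batch, classes)
           "attn_scores": rng.normal(size=(8, 4, 16, 16)),   # (batch, heads, tokens, tokens)
           "hidden": rng.normal(size=(8, 16, 32))}           # (batch, tokens, dim)
student = {key: value + 0.1 * rng.normal(size=value.shape) for key, value in teacher.items()}
print(multiplexing_distill_loss(teacher, teacher))           # 0.0: identical outputs cost nothing
print(multiplexing_distill_loss(teacher, student, temperature=2.0))
```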
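Experiments show that Weight Multiplexing clearly improves accuracy over the baselines and compresses pre-trained Vision Transformers by 2×. For example, with the proposed Weight Multiplexing, the 12-layer Mini-Swin-B model is 0.8% more accurate than the 24-layer Swin-B. Moreover, Mini-DeiT-B, with 9M parameters, reaches 79.8% Top-1 accuracy on ImageNet while being 9.7× smaller than DeiT-B. The 12M tiny model compressed with this method also transfers well to downstream object detection. It reaches 48.6 AP on the COCO validation set, comparable to the original Swin-T with 28M parameters.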

Main contributions

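1. A systematic study of the effectiveness of weight sharing in Vision Transformers, together with an analysis of why weight sharing causes problems;
2. MiniViT, a new general-purpose compression framework for Vision Transformers. Experimental results show that MiniViT achieves large compression ratios without losing accuracy, and that its performance transfers well to downstream tasks.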

3. Related Work

3.1 Vision Transformer

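Although the Transformer was originally designed for NLP, it has recently shown great potential in computer vision as well. A Vision Transformer first splits the input image into a sequence of 2D patches, called tokens. These patches are then flattened and mapped to D-dimensional vectors (the patch embeddings) by a linear projection or by stacked CNN layers. To retain positional information, positional embeddings are added to the patch embeddings. The combined embeddings are then fed into a Transformer encoder. Finally, a linear layer produces the classification output.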
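The front end of this pipeline can be expressed with array reshapes, as in the minimal NumPy sketch below. It rests on several assumptions:

- a 224×224 RGB input, 16×16 patches and D = 768;
- random weights in place of trained ones;
- the encoder omitted;
- mean pooling before the linear head. ViT and DeiT actually classify from a [CLS] token; this pooling is a simplification for illustration.

```python
import numpy as np


def patchify(image, patch):
    """Split an (H, W, C) image into a (num_patches, patch*patch*C) matrix, row-major over patches."""
    h, w, c = image.shape
    assert h % patch == 0 and w % patch == 0, "image size must be divisible by the patch size"
    x = image.reshape(h // patch, patch, w // patch, patch, c)
    x = x.transpose(0, 2, 1, 3, 4)             # (patch rows, patch cols, patch, patch, C)
    return x.reshape(-1, patch * patch * c)    # flatten every patch into one vector


rng = np.random.default_rng(0)
image = rng.normal(size=(224, 224, 3))
D = 768                                         # embedding width (DeiT-B value, assumed here)

patches = patchify(image, 16)                   # 14 x 14 = 196 patches of 16*16*3 = 768 values
w_embed = rng.normal(scale=0.02, size=(patches.shape[1], D))    # linear projection (random stand-in)
pos_embed = rng.normal(scale=0.02, size=(patches.shape[0], D))  # one positional embedding per token
tokens = patches @ w_embed + pos_embed          # patch embeddings + positional embeddings

# ... the Transformer encoder would transform `tokens` here ...
w_head = rng.normal(scale=0.02, size=(D, 1000))
logits = tokens.mean(axis=0) @ w_head           # mean-pool tokens, then a linear classifier
print(patches.shape, tokens.shape, logits.shape)  # (196, 768) (196, 768) (1000,)
```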

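The Transformer encoder consists of alternating MSA and MLP blocks. Layer normalization (LN) is applied before each block, and a residual connection is applied after each block. The MSA and MLP blocks are described in detail below.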
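One way to read "LN before each block, residual connection after it" is the pre-norm layout in the NumPy sketch below. It is simplified: single-head attention without an output projection stands in for MSA, the MLP uses GELU with a 4× hidden width, and LayerNorm has no learnable parameters. All names are hypothetical.

```python
import numpy as np


def layer_norm(x, eps=1e-6):
    """LayerNorm over the feature axis (learnable scale and shift omitted for brevity)."""
    mu = x.mean(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(x.var(axis=-1, keepdims=True) + eps)


def gelu(x):
    """tanh approximation of GELU."""
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x ** 3)))


def attention(x, w_q, w_k, w_v):
    """Single-head scaled dot-product attention, standing in for MSA."""
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    scores = q @ k.T / np.sqrt(q.shape[-1])
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ v


def encoder_block(x, p):
    """Pre-norm block: LN before each sub-block, residual connection around it."""
    x = x + attention(layer_norm(x), p["w_q"], p["w_k"], p["w_v"])  # MSA sub-block
    x = x + gelu(layer_norm(x) @ p["w_1"]) @ p["w_2"]               # MLP sub-block
    return x


rng = np.random.default_rng(0)
D, N = 32, 10
shapes = {"w_q": (D, D), "w_k": (D, D), "w_v": (D, D), "w_1": (D, 4 * D), "w_2": (4 * D, D)}
params = {name: rng.normal(scale=D ** -0.5, size=shape) for name, shape in shapes.items()}
x = rng.normal(size=(N, D))
print(encoder_block(x, params).shape)   # (10, 32)
```

Because of the residual connections, a block whose weights are all zero leaves its input unchanged. This is one reason deep pre-norm stacks are comparatively easy to optimize.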

MSA

Let $M$ be the number of heads, which are also called self-attention modules. Given an input sequence $\mathbf{Z}_0 \in \mathbb{R}^{N \times D}$, the $k$-th head generates the query, key and value through linear projections. These are denoted $\mathbf{Q}_k \in \mathbb{R}^{N \times d}$, $\mathbf{K}_k \in \mathbb{R}^{N \times d}$ and $\mathbf{V}_k \in \mathbb{R}^{N \times d}$. Here $N$ is the number of tokens, and $D$ and $d$ are the dimensions of the patch embeddings and of the Q-K-V matrices, respectively. Each head then computes, for every position in the sequence, a weighted sum of all the values. These weights are called attention weights and are denoted $\mathbf{A}_k$. They are based on the pairwise similarity between elements of the sequence, namely

$$
\begin{aligned}
&\mathbf{h}_{k}=\mathbf{A}_{k} \mathbf{V}_{k}, \text{ and } \\
&\mathbf{A}_{k}=\operatorname{softmax}\left(\frac{\mathbf{Q}_{k} \mathbf{K}_{k}^{T}}{\sqrt{d}}\right)
\end{aligned}
$$
   
  
 </span><span class="katex-html"><span class="base"><span class="strut" style="height: 4.26836em; vertical-align: -1.88418em;"></span><span class="mord"><span class="mtable"><span class="col-align-r"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 2.38418em;"><span class="" style="top: -5.06251em;"><span class="pstrut" style="height: 3.51833em;"></span><span class="mord"></span></span><span class="" style="top: -2.88418em;"><span class="pstrut" style="height: 3.51833em;"></span><span class="mord"></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height: 1.88418em;"><span class=""></span></span></span></span></span><span class="col-align-l"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 2.38418em;"><span class="" style="top: -5.06251em;"><span class="pstrut" style="height: 3.51833em;"></span><span class="mord"><span class="mord"></span><span class="mord"><span class="mord"><span class="mord mathbf">h</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.336108em;"><span class="" style="top: -2.55em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathdefault mtight" style="margin-right: 0.03148em;">k</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height: 0.15em;"><span class=""></span></span></span></span></span></span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mord"><span class="mord"><span class="mord mathbf">A</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.336108em;"><span 
class="" style="top: -2.55em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathdefault mtight" style="margin-right: 0.03148em;">k</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height: 0.15em;"><span class=""></span></span></span></span></span></span><span class="mord"><span class="mord"><span class="mord mathbf" style="margin-right: 0.01597em;">V</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.336108em;"><span class="" style="top: -2.55em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathdefault mtight" style="margin-right: 0.03148em;">k</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height: 0.15em;"><span class=""></span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord text"><span class="mord">&nbsp;and&nbsp;</span></span></span></span><span class="" style="top: -2.88418em;"><span class="pstrut" style="height: 3.51833em;"></span><span class="mord"><span class="mord"></span><span class="mord"><span class="mord"><span class="mord mathbf">A</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.336108em;"><span class="" style="top: -2.55em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathdefault mtight" style="margin-right: 0.03148em;">k</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span 
class="vlist" style="height: 0.15em;"><span class=""></span></span></span></span></span></span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mop"><span class="mord mathrm">s</span><span class="mord mathrm">o</span><span class="mord mathrm" style="margin-right: 0.07778em;">f</span><span class="mord mathrm">t</span><span class="mord mathrm">m</span><span class="mord mathrm">a</span><span class="mord mathrm">x</span></span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="minner"><span class="mopen delimcenter" style="top: 0em;"><span class="delimsizing size3">(</span></span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 1.51833em;"><span class="" style="top: -2.17778em;"><span class="pstrut" style="height: 3em;"></span><span class="mord"><span class="mord sqrt"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.93222em;"><span class="svg-align" style="top: -3em;"><span class="pstrut" style="height: 3em;"></span><span class="mord" style="padding-left: 0.833em;"><span class="mord mathdefault">d</span></span></span><span class="" style="top: -2.89222em;"><span class="pstrut" style="height: 3em;"></span><span class="hide-tail" style="min-width: 0.853em; height: 1.08em;">
                        <svg width="400em" height="1.08em" viewBox="0 0 400000 1080" preserveAspectRatio="xMinYMin slice">
                         <path d="M95,702c-2.7,0,-7.17,-2.7,-13.5,-8c-5.8,-5.3,-9.5,

-10,-9.5,-14c0,-2,0.3,-3.3,1,-4c1.3,-2.7,23.83,-20.7,67.5,-54c44.2,-33.3,65.8,
-50.3,66.5,-51c1.3,-1.3,3,-2,5,-2c4.7,0,8.7,3.3,12,10s173,378,173,378c0.7,0,
35.3,-71,104,-213c68.7,-142,137.5,-285,206.5,-429c69,-144,104.5,-217.7,106.5,
-221c5.3,-9.3,12,-14,20,-14H400000v40H845.2724s-225.272,467,-225.272,467
s-235,486,-235,486c-2.7,4.7,-9,7,-19,7c-6,0,-10,-1,-12,-3s-194,-422,-194,-422
s-65,47,-65,47z M834 80H400000v40H845z">
QkKkT)
  其中, 对输入矩阵的每一行进行

    softmax
   
   
    ⁡
   
   
    (
   
   
    ⋅
   
   
    )
   
  
  
   \operatorname{softmax}(\cdot)
  
 
</span><span class="katex-html"><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mop"><span class="mord mathrm">s</span><span class="mord mathrm">o</span><span class="mord mathrm" style="margin-right: 0.07778em;">f</span><span class="mord mathrm">t</span><span class="mord mathrm">m</span><span class="mord mathrm">a</span><span class="mord mathrm">x</span></span><span class="mopen">(</span><span class="mord">⋅</span><span class="mclose">)</span></span></span></span></span> 操作。最后, 将一个全连接层应用于所有 Head 的输出的连接。</p> 
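  To make the per-head computation concrete, here is a minimal NumPy sketch of multi-head self-attention as described above. It assumes bias-free Q/K/V/output projections, equal-width heads, and that $d$ is the per-head dimension; the function names are illustrative and are not taken from the MiniViT code.

```python
import numpy as np

def softmax_rows(x):
    # Row-wise softmax, shifted by the row max for numerical stability
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def multi_head_self_attention(X, Wq, Wk, Wv, Wo, num_heads):
    """X: (N, D) token embeddings; Wq/Wk/Wv/Wo: (D, D) projections (assumed bias-free)."""
    N, D = X.shape
    d = D // num_heads  # per-head dimension used in the sqrt(d) scaling
    Q, K, V = X @ Wq, X @ Wk, X @ Wv
    heads = []
    for k in range(num_heads):
        sl = slice(k * d, (k + 1) * d)
        Qk, Kk, Vk = Q[:, sl], K[:, sl], V[:, sl]
        Ak = softmax_rows(Qk @ Kk.T / np.sqrt(d))  # A_k = softmax(Q_k K_k^T / sqrt(d))
        heads.append(Ak @ Vk)                      # h_k = A_k V_k
    # Fully connected layer applied to the concatenated head outputs
    return np.concatenate(heads, axis=-1) @ Wo

rng = np.random.default_rng(0)
N, D, M = 4, 8, 2
X = rng.standard_normal((N, D))
Wq, Wk, Wv, Wo = (rng.standard_normal((D, D)) * 0.1 for _ in range(4))
out = multi_head_self_attention(X, Wq, Wk, Wv, Wo, M)
print(out.shape)  # (4, 8)
```

  With `num_heads=1` the loop collapses to a single softmax over the full embedding, which is a convenient sanity check.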
MLP

  The MLP block consists of two FC layers with an activation function denoted by $\sigma(\cdot)$, usually GELU. Let $\mathbf{Y} \in \mathbb{R}^{N \times d}$ be the input of the MLP. The output of the MLP can be written as

$$
\mathbf{H}=\sigma\left(\mathbf{Y} \mathbf{W}^{(1)}+\mathbf{b}^{(1)}\right) \mathbf{W}^{(2)}+\mathbf{b}^{(2)}
$$

  where $\mathbf{W}^{(1)} \in \mathbb{R}^{d \times d'}$, $\mathbf{b}^{(1)} \in \mathbb{R}^{d'}$, $\mathbf{W}^{(2)} \in \mathbb{R}^{d' \times d}$ and $\mathbf{b}^{(2)} \in \mathbb{R}^{d}$ are the weights and biases of the first and second layers, respectively. Usually $d' > d$.
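  A minimal NumPy sketch of the MLP formula above, useful mainly for checking the shapes: $\mathbf{W}^{(1)}$ maps $d \to d'$ and $\mathbf{W}^{(2)}$ maps back $d' \to d$. The tanh approximation of GELU and the ratio $d' = 4d$ are assumptions chosen for illustration.

```python
import numpy as np

def gelu(x):
    # tanh approximation of GELU (an assumption; the exact form uses the Gaussian CDF)
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

def mlp(Y, W1, b1, W2, b2):
    """H = sigma(Y W1 + b1) W2 + b2, with Y: (N, d), W1: (d, d'), W2: (d', d)."""
    return gelu(Y @ W1 + b1) @ W2 + b2

rng = np.random.default_rng(0)
N, d = 4, 8
d_hidden = 4 * d  # d' > d; the 4x ratio is a common ViT choice, assumed here
W1, b1 = rng.standard_normal((d, d_hidden)) * 0.1, np.zeros(d_hidden)
W2, b2 = rng.standard_normal((d_hidden, d)) * 0.1, np.zeros(d)
H = mlp(rng.standard_normal((N, d)), W1, b1, W2, b2)
print(H.shape)  # (4, 8): the MLP preserves the token embedding size
```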

3.2 Weight Sharing

[Figure]
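  Weight sharing is a simple yet effective way to improve parameter efficiency. Its core idea is to share parameters across layers, as shown in Figure 2(a). Mathematically, weight sharing can be formulated as a recursive update of a single Transformer block $f$ (i.e., one shared layer):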
$$
\mathbf{Z}_{i+1}=f\left(\mathbf{Z}_{i} ; \theta\right), \quad i=0, \ldots, L-1 \tag{4}
$$

  where $\mathbf{Z}_i$ is the feature embedding of the sequence at the $i$-th layer, $L$ is the total number of layers, and $\theta$ denotes the weights of the Transformer block shared across all layers.

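  Many works have explored and verified the effectiveness of weight sharing in natural-language Transformer models. It keeps the number of parameters from growing with network depth without severely hurting performance, which improves parameter efficiency.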
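  The recursion in Eq. (4) and its effect on model size can be illustrated with a short NumPy sketch. `toy_block` is a hypothetical residual layer standing in for a full Transformer block; only the reuse of one $\theta$ across all $L$ iterations and the resulting parameter count reflect the idea described above.

```python
import numpy as np

def toy_block(Z, theta):
    # Hypothetical stand-in for a Transformer block f(Z; theta): a residual layer
    return Z + np.tanh(Z @ theta["W"] + theta["b"])

def num_params(theta):
    return sum(p.size for p in theta.values())

def forward_shared(Z0, theta, L):
    """Z_{i+1} = f(Z_i; theta) for i = 0..L-1, reusing the same theta at every layer."""
    Z = Z0
    for _ in range(L):
        Z = toy_block(Z, theta)
    return Z

rng = np.random.default_rng(0)
N, d, L = 4, 16, 12
theta = {"W": rng.standard_normal((d, d)) * 0.1, "b": np.zeros(d)}
Z_L = forward_shared(rng.standard_normal((N, d)), theta, L)

shared = num_params(theta)        # stored once, regardless of depth
unshared = L * num_params(theta)  # L independent blocks of the same shape
print(shared, unshared)  # 272 3264
```

  Depth now costs computation but no extra parameters, which is exactly why identical weights in every layer can become a bottleneck, as discussed next.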

4. Method

4.1 Weight Multiplexing

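  The potential of weight sharing has been demonstrated in natural language processing; however, its effect on Vision Transformers remains unclear. To examine this, the authors directly applied cross-layer weight sharing to the DeiT-S and Swin-B models and observed two problems: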

  • Training instability
  • Performance degradation
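[Figure]

  According to the authors' experimental analysis, the strict identity of weights across layers is the main cause of these problems. In particular, after weight sharing the ℓ2-norm of the gradients becomes large and fluctuates across different Transformer blocks, as shown in Figure 4.

[Figure]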

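  As shown in Figure 5, the CKA values indicate that the feature maps produced by the model after weight sharing are much less correlated with those of the original model. To address these problems, and inspired by multiplexing techniques in telecommunications, the authors propose a new Transformer compression technique called Weight Multiplexing. It combines the weights of multiple layers into a single shared weight, and introduces transformation and distillation to increase weight diversity.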
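  CKA is used above to compare feature maps before and after weight sharing. The sketch below implements the commonly used linear variant of CKA between two feature matrices computed on the same $n$ inputs; whether MiniViT uses the linear or a kernel variant is an assumption here, and `linear_cka` is an illustrative name.

```python
import numpy as np

def linear_cka(X, Y):
    """Linear CKA between features X: (n, p1) and Y: (n, p2) over the same n inputs."""
    X = X - X.mean(axis=0, keepdims=True)  # center each feature column
    Y = Y - Y.mean(axis=0, keepdims=True)
    hsic = np.linalg.norm(Y.T @ X, "fro") ** 2
    norm_x = np.linalg.norm(X.T @ X, "fro")
    norm_y = np.linalg.norm(Y.T @ Y, "fro")
    return hsic / (norm_x * norm_y)

rng = np.random.default_rng(0)
feats = rng.standard_normal((100, 32))
Q, _ = np.linalg.qr(rng.standard_normal((32, 32)))  # random orthogonal matrix
print(round(linear_cka(feats, feats @ Q), 6))  # 1.0: invariant to rotations of the features
```

  Because CKA is invariant to orthogonal transforms and isotropic scaling, a low value means the two layers encode genuinely different information, not just a rotated or rescaled copy of it.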

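  More specifically, as shown in Figure 2(b), the proposed weight multiplexing method consists of:

  1. Sharing weights across multiple Transformer blocks, which can be viewed as the combining step of multiplexing;
  2. Introducing a transformation in each layer to imitate demultiplexing;
  3. Applying knowledge distillation to increase the similarity of feature representations between the models before and after compression.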

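  Based on Eq. (4), weight multiplexing can be reformulated as:

$$
\mathbf{Z}_{i+1}=f\left(\mathbf{Z}_{i} ; \theta, \theta_{i}^{\prime}\right), \quad i=0, \ldots, L-1
$$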
  where $\theta'_i$ denotes the weights of the transformation blocks in the $i$-th Transformer layer. Note that $\theta'_i$ contains far fewer parameters than $\theta$.
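  A minimal sketch of the multiplexed recursion, under a simplifying assumption: $\theta'_i$ is modeled as a hypothetical per-layer scale-and-shift on the block input, not the attention and MLP transformations MiniViT actually uses. It only shows the bookkeeping: one large shared $\theta$ plus $L$ small, unshared $\theta'_i$.

```python
import numpy as np

def toy_block(Z, theta, theta_i):
    # Hypothetical stand-in for f(Z; theta, theta'_i): the large weights come from the
    # shared theta, while a cheap per-layer scale/shift theta'_i makes each layer differ
    Zt = Z * theta_i["gamma"] + theta_i["beta"]
    return Z + np.tanh(Zt @ theta["W"] + theta["b"])

def forward_multiplexed(Z0, theta, thetas_prime):
    Z = Z0
    for theta_i in thetas_prime:  # Z_{i+1} = f(Z_i; theta, theta'_i)
        Z = toy_block(Z, theta, theta_i)
    return Z

def param_count(params):
    return sum(v.size for v in params.values())

rng = np.random.default_rng(0)
N, d, L = 4, 16, 12
theta = {"W": rng.standard_normal((d, d)) * 0.1, "b": np.zeros(d)}
thetas_prime = [{"gamma": 1.0 + 0.1 * rng.standard_normal(d), "beta": np.zeros(d)}
                for _ in range(L)]
Z_L = forward_multiplexed(rng.standard_normal((N, d)), theta, thetas_prime)

total = param_count(theta) + sum(param_count(t) for t in thetas_prime)
print(param_count(theta), total)  # 272 656
```

  The per-layer parameters grow linearly in $d$ here, while the shared part grows quadratically, so the multiplexed model stays much smaller than 12 independent blocks (3264 parameters in this toy setting).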

4.1.1 Weight Transformation

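  Weight transformation is imposed on the attention matrices and the feed-forward networks. It allows each layer to be different, which improves parameter diversity and the representation capacity of the model.

[Figure]

  As shown in the figure, the parameters of the transformation kernels are not shared across layers, while all other layers of the original Transformer except LayerNorm are shared. Since the shared layers account for the vast majority of the model parameters, weight multiplexing makes the model only slightly larger than plain weight sharing.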

4.1.2 Transformation for MSA

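  To improve parameter diversity, two linear transformations are inserted before and after the softmax module, respectively.

[Equation image]

  where the two transformation matrices are the linear transformations applied before and after softmax, respectively. These linear transformations make each attention matrix different, and they also combine information across attention heads to increase parameter variance.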
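  Because the transformed attention formula is not reproduced in the text, the sketch below rests on an assumption: the two transforms are per-layer $M \times M$ matrices ($M$ = number of heads), in the style of talking-heads attention, that mix the attention logits across heads before softmax and the attention maps across heads after softmax. The names `F1` and `F2` are hypothetical.

```python
import numpy as np

def softmax_rows(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def attention_with_head_mixing(Q, K, V, F1, F2):
    """Q, K, V: (M, N, d) per-head tensors; F1, F2: (M, M) per-layer mixing matrices (assumed)."""
    d = Q.shape[-1]
    logits = Q @ K.transpose(0, 2, 1) / np.sqrt(d)  # (M, N, N), one map per head
    logits = np.einsum("nm,mij->nij", F1, logits)    # linear transform across heads before softmax
    A = softmax_rows(logits)
    A = np.einsum("kn,nij->kij", F2, A)              # linear transform across heads after softmax
    return A @ V                                     # per-head outputs, shape (M, N, d)

rng = np.random.default_rng(0)
M, N, d = 3, 5, 4
Q, K, V = (rng.standard_normal((M, N, d)) for _ in range(3))
I = np.eye(M)
plain = softmax_rows(Q @ K.transpose(0, 2, 1) / np.sqrt(d)) @ V
print(np.allclose(attention_with_head_mixing(Q, K, V, I, I), plain))  # True
```

  With identity matrices the layer reduces to ordinary per-head attention, so the transforms only add $2M^2$ unshared parameters per layer while letting every layer reuse the same large Q/K/V weights.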

4.1.3 Transformation for MLP

  In addition, a lightweight transformation is applied to the MLP to further increase parameter diversity. Specifically, let the input be $Y=[y_1, \ldots, y_d]$, where $y_l \in \mathbb{R}^N$ collects the $l$-th position of the embedding vectors of all $N$ tokens. A linear transformation then converts $Y$ into $Y'=[C^{(1)} y_1, \ldots, C^{(d)} y_d]$, where $C^{(1)}, \ldots, C^{(d)} \in \mathbb{R}^{N \times N}$ are independent weight matrices of the linear layer. Eq. 3 is then reformulated as:

$$\mathbf{H}=\sigma\left(\mathbf{Y}^{\prime} \mathbf{W}^{(1)}+\mathbf{b}^{(1)}\right) \mathbf{W}^{(2)}+\mathbf{b}^{(2)}$$

  To reduce the number of parameters and introduce locality into the transformation, depth-wise convolution is used to sparsify and share the weights within each weight matrix. The transformation then needs only $K^2 d$ parameters, far fewer than the $N^2 d$ of dense matrices ($K \ll N$), where $K$ is the convolution kernel size. After the transformation, the outputs of the MLP become more diverse, which improves parameter efficacy.
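A NumPy sketch of this MLP transformation, assuming the N tokens lie on an H×W grid (N = H·W, row-major order) so that each $C^{(l)}$ can be realized as a K×K depth-wise convolution over that grid with zero padding and odd K. The tanh approximation of GELU stands in for σ; all function names are hypothetical.

```python
import numpy as np


def depthwise_conv_tokens(y, kernels, grid_hw):
    """Apply one K x K filter per channel over the token grid.

    y: (N, d) token embeddings with N = H * W (row-major grid order).
    kernels: (d, K, K), K odd; zero padding keeps the grid size.
    Column l of the result equals C^(l) y_l for a sparse, banded C^(l).
    """
    h, w = grid_hw
    n, d = y.shape
    assert n == h * w, "tokens must fill the H x W grid"
    k = kernels.shape[-1]
    pad = k // 2
    x = np.pad(y.reshape(h, w, d), ((pad, pad), (pad, pad), (0, 0)))
    out = np.zeros((h, w, d))
    for i in range(k):
        for j in range(k):
            # shift the padded grid and weight every channel by its own tap
            out += x[i:i + h, j:j + w, :] * kernels[:, i, j]
    return out.reshape(n, d)


def gelu(x):
    # tanh approximation of GELU, used here as the activation sigma
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x ** 3)))


def transformed_mlp(y, kernels, grid_hw, w1, b1, w2, b2):
    """H = sigma(Y' W1 + b1) W2 + b2, with Y' from the depth-wise transform."""
    y_prime = depthwise_conv_tokens(y, kernels, grid_hw)
    return gelu(y_prime @ w1 + b1) @ w2 + b2
```

The `kernels` array holds exactly $K^2 d$ values. For a 14×14 grid (N = 196), d = 768 and K = 3 that is 6,912 parameters, versus 29,503,488 for dense $C^{(l)}$ matrices.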

  In theory, with these transformations the weight-shared layers can recover the behavior of the pre-trained model, similar to a demultiplexing process. This alleviates the training instability and performance degradation, since neither issue is observed in the original model. Similar transformations have also been used to improve the performance of transformers without weight sharing, such as Talking-heads Attention and CeiT.

4.2 Weight Distillation

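  To compress the large pre-trained model and address the performance degradation caused by weight sharing, the authors further adopt weight distillation to transfer knowledge from the large model to the small, compact model. Three distillation methods are considered for the transformer blocks: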

  • Prediction-Logit Distillation
  • Self-Attention Distillation
  • Hidden-State Distillation

4.2.1 Prediction-Logit Distillation

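  Hinton et al. showed that a deep learning model can achieve better performance by imitating, during training, the output behavior of a well-performing teacher model. The authors use this idea to introduce a prediction loss: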
$$\mathcal{L}_{pred}=\mathrm{CE}\left(\mathrm{softmax}(Z_s / T),\ \mathrm{softmax}(Z_t / T)\right)$$
  where $Z_s$ and $Z_t$ are the logits predicted by the student and teacher models, respectively, and $T$ is a temperature that controls the smoothness of the logits. In the experiments, $T=1$. CE denotes the cross-entropy loss.
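A minimal NumPy sketch of this prediction loss, taking CE as the soft-label cross-entropy $-\sum_c p_t(c) \log p_s(c)$ between the temperature-softened teacher and student distributions, averaged over the batch. The function names are hypothetical.

```python
import numpy as np


def log_softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)  # stabilize before exponentiating
    return z - np.log(np.exp(z).sum(axis=axis, keepdims=True))


def prediction_logit_loss(z_s, z_t, temperature=1.0):
    """Soft cross-entropy CE(softmax(z_s / T), softmax(z_t / T)).

    z_s, z_t: (B, num_classes) student and teacher logits.
    The teacher distribution acts as the target."""
    p_t = np.exp(log_softmax(z_t / temperature))
    log_p_s = log_softmax(z_s / temperature)
    return float(-(p_t * log_p_s).sum(axis=-1).mean())
```

When the student reproduces the teacher's logits, the loss equals the entropy of the teacher's distribution, its minimum possible value; any mismatch raises it.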

4.2.2 Self-Attention Distillation

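  Recent work has shown that using the attention maps in transformer layers to guide the training of the student model is effective. To resolve the dimension mismatch caused by the different numbers of heads in the student and teacher models, and inspired by MiniLMv2, a cross-entropy loss is applied to the relations among the queries, keys and values in MSA.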
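A minimal NumPy sketch of this relation-based loss. It assumes the relation form $R_{ij}=\mathrm{softmax}(S_i S_j^{T}/\sqrt{Md})$ with $(S_1, S_2, S_3)=(Q, K, V)$ concatenated over all heads, as defined in this section, and it hypothetically distills all nine $(i, j)$ pairs with equal weight; the exact pair set and weighting are not stated here. Because every $R_{ij}$ is $N \times N$, teacher and student may use different numbers of heads as long as they see the same $N$ tokens.

```python
import numpy as np


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def relation(s_i, s_j):
    """s_i, s_j: (N, M*d) head-concatenated matrices; returns (N, N)."""
    return softmax(s_i @ s_j.T / np.sqrt(s_i.shape[-1]))


# Assumption: all nine (Q, K, V) pairs, equally weighted.
PAIRS = [(i, j) for i in range(3) for j in range(3)]


def self_attention_distill_loss(student_qkv, teacher_qkv, pairs=PAIRS, eps=1e-12):
    """student_qkv / teacher_qkv: tuples (Q, K, V), each (N, M*d).
    The width M*d may differ between student and teacher; N must match."""
    loss = 0.0
    for i, j in pairs:
        r_t = relation(teacher_qkv[i], teacher_qkv[j])
        r_s = relation(student_qkv[i], student_qkv[j])
        # row-wise cross-entropy CE(R_t, R_s), averaged over the N rows
        loss += -(r_t * np.log(r_s + eps)).sum(axis=-1).mean()
    return float(loss / len(pairs))
```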

  First, the matrices of all heads are concatenated. For example, define $Q=[Q_1, \ldots, Q_M] \in \mathbb{R}^{N \times Md}$, and define $K, V \in \mathbb{R}^{N \times Md}$ in the same way. To simplify notation, $S_1$, $S_2$ and $S_3$ denote $Q$, $K$ and $V$, respectively. Then one can generate the relation matrices defined by $R_{ij}=\mathrm{softmax}\left(S_i S_j^{T} / \sqrt{Md}\right)$
        <svg width="400em" height="1.08em" viewBox="0 0 400000 1080" preserveAspectRatio="xMinYMin slice">
         <path d="M95,702c-2.7,0,-7.17,-2.7,-13.5,-8c-5.8,-5.3,-9.5,

-10,-9.5,-14c0,-2,0.3,-3.3,1,-4c1.3,-2.7,23.83,-20.7,67.5,-54c44.2,-33.3,65.8,
-50.3,66.5,-51c1.3,-1.3,3,-2,5,-2c4.7,0,8.7,3.3,12,10s173,378,173,378c0.7,0,
35.3,-71,104,-213c68.7,-142,137.5,-285,206.5,-429c69,-144,104.5,-217.7,106.5,
-221c5.3,-9.3,12,-14,20,-14H400000v40H845.2724s-225.272,467,-225.272,467
s-235,486,-235,486c-2.7,4.7,-9,7,-19,7c-6,0,-10,-1,-12,-3s-194,-422,-194,-422
s-65,47,-65,47z M834 80H400000v40H845z">
) 定义的 9 个不同的关系矩阵。注意,

     R
    
    
     12
    
   
  
  
   R_{12}
  
 
</span><span class="katex-html"><span class="base"><span class="strut" style="height: 0.83333em; vertical-align: -0.15em;"></span><span class="mord"><span class="mord mathdefault" style="margin-right: 0.00773em;">R</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.301108em;"><span class="" style="top: -2.55em; margin-left: -0.00773em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight">1</span><span class="mord mtight">2</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height: 0.15em;"><span class=""></span></span></span></span></span></span></span></span></span></span> 是注意力矩阵 a, Self-Attention Distillation 损失可以表示为:<br> <img src="https://img-blog.csdnimg.cn/ad01853acee647778f05002f4fbcb754.png" alt="在这里插入图片描述"></p> 

  其中,

     R
    
    
     
      i
     
     
      j
     
     
      ,
     
     
      n
     
    
   
  
  
   R_{i j, n}
  
 
</span><span class="katex-html"><span class="base"><span class="strut" style="height: 0.969438em; vertical-align: -0.286108em;"></span><span class="mord"><span class="mord mathdefault" style="margin-right: 0.00773em;">R</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.311664em;"><span class="" style="top: -2.55em; margin-left: -0.00773em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathdefault mtight">i</span><span class="mord mathdefault mtight" style="margin-right: 0.05724em;">j</span><span class="mpunct mtight">,</span><span class="mord mathdefault mtight">n</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height: 0.286108em;"><span class=""></span></span></span></span></span></span></span></span></span></span> 表示 <span class="katex--inline"><span class="katex"><span class="katex-mathml">

 
  
   
    
     R
    
    
     
      i
     
     
      j
     
    
   
  
  
   R_{i j}
  
 
</span><span class="katex-html"><span class="base"><span class="strut" style="height: 0.969438em; vertical-align: -0.286108em;"></span><span class="mord"><span class="mord mathdefault" style="margin-right: 0.00773em;">R</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.311664em;"><span class="" style="top: -2.55em; margin-left: -0.00773em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathdefault mtight">i</span><span class="mord mathdefault mtight" style="margin-right: 0.05724em;">j</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height: 0.286108em;"><span class=""></span></span></span></span></span></span></span></span></span></span> 的第 <span class="katex--inline"><span class="katex"><span class="katex-mathml">

 
  
   
    n
   
  
  
   \mathrm{n}
  
 
</span><span class="katex-html"><span class="base"><span class="strut" style="height: 0.43056em; vertical-align: 0em;"></span><span class="mord"><span class="mord mathrm">n</span></span></span></span></span></span> 行。</p> 
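The self-attention relation loss can be sketched in NumPy as follows. This is a minimal illustration, not the authors' implementation. It makes two assumptions:

- Q, K and V are given per layer as `(N, M·d)` arrays with all heads concatenated.
- `CE(R^s, R^t)` takes the teacher row as the target distribution.

The helper names are hypothetical. Every `R_ij` is `N×N`, so student and teacher may have different numbers of heads or head widths.

```python
import numpy as np


def softmax(x, axis=-1):
    # Numerically stable softmax along `axis`.
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def relation_matrices(q, k, v):
    """Return {(i, j): R_ij} with R_ij = softmax(S_i S_j^T / sqrt(M*d)).

    q, k, v have shape (N, M*d): N tokens, all M heads concatenated
    (an assumed layout). Each R_ij has shape (N, N) whatever M*d is.
    """
    s = (q, k, v)                     # S_1 = Q, S_2 = K, S_3 = V
    scale = np.sqrt(q.shape[-1])      # sqrt(M * d)
    return {(i + 1, j + 1): softmax(s[i] @ s[j].T / scale)
            for i in range(3) for j in range(3)}


def row_ce(r_student, r_teacher, eps=1e-12):
    # Mean over rows n of CE(R^s_n, R^t_n) = -sum_k R^t_{n,k} * log R^s_{n,k}
    # (the teacher row is taken as the target distribution, an assumption).
    return float(-(r_teacher * np.log(r_student + eps)).sum(axis=-1).mean())


def attn_distill_loss(student_qkv, teacher_qkv):
    # L_attn = 1/(9N) * sum_n sum_{i,j} CE(R^s_{ij,n}, R^t_{ij,n})
    rs = relation_matrices(*student_qkv)
    rt = relation_matrices(*teacher_qkv)
    return sum(row_ce(rs[key], rt[key]) for key in rs) / 9.0


rng = np.random.default_rng(0)
teacher = tuple(rng.normal(size=(5, 16)) for _ in range(3))  # 5 tokens, M*d = 16
student = tuple(rng.normal(size=(5, 8)) for _ in range(3))   # narrower student, M*d = 8
print(attn_distill_loss(student, teacher))
```

In the example, the teacher is 16 wide and the student is 8 wide on the same 5 tokens. The loss is still well defined because only the `N×N` relation matrices are compared.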

4.2.3 Hidden-State Distillation

  Similarly, relation matrices can be built for the hidden states, i.e., the features output by the MLP. Let $H\in\mathbb{R}^{N\times d}$ denote the hidden states of a Transformer layer. The relation-matrix-based hidden-state distillation loss is defined as

$$\mathcal{L}_{hddn}=\frac{1}{N}\sum_{n=1}^{N} CE\left(R_{H,n}^{s},R_{H,n}^{t}\right)$$

  where $R_{H,n}$ denotes the $n$-th row of $R_H$, computed as $R_H=\operatorname{softmax}\left(HH^{T}/\sqrt{d}\right)$.
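A matching NumPy sketch of the hidden-state relation loss follows. As above, it assumes that `CE` uses the teacher row as the target. The function names are illustrative only. Since `R_H` is `N×N`, the student and teacher hidden widths `d` need not match.

```python
import numpy as np


def softmax(x, axis=-1):
    # Numerically stable softmax along `axis`.
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def hidden_relation(h):
    # R_H = softmax(H H^T / sqrt(d)) for hidden states H of shape (N, d).
    return softmax(h @ h.T / np.sqrt(h.shape[-1]))


def hddn_distill_loss(h_student, h_teacher, eps=1e-12):
    # L_hddn = 1/N * sum_n CE(R^s_{H,n}, R^t_{H,n});
    # teacher rows are used as the target distribution (an assumption).
    r_s = hidden_relation(h_student)
    r_t = hidden_relation(h_teacher)
    return float(-(r_t * np.log(r_s + eps)).sum(axis=-1).mean())


rng = np.random.default_rng(0)
h_teacher = rng.normal(size=(6, 32))   # N = 6 tokens, teacher width d = 32
h_student = rng.normal(size=(6, 12))   # same tokens, narrower student (d = 12)
print(hddn_distill_loss(h_student, h_teacher))
```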

4.2.4 Overall Distillation Loss

  According to the authors' observations, using only the prediction soft labels yields better performance than using prediction soft labels + ground-truth labels, so the final distillation objective is

$$\mathcal{L}_{\text{train}}=\mathcal{L}_{\text{pred}}+\beta\,\mathcal{L}_{attn}+\gamma\,\mathcal{L}_{hddn},$$

  where $\beta$ and $\gamma$ are hyperparameters with default values of 1 and 0.1, respectively.
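Below is a sketch of how the three terms combine. The prediction term is written as a soft-label cross-entropy against the teacher's softmax output, with no ground-truth term, in line with the observation above. The temperature `T` and all function names are assumptions of this sketch. The values passed for `l_attn` and `l_hddn` are placeholders standing in for the two losses defined earlier.

```python
import numpy as np


def log_softmax(z):
    # Numerically stable log-softmax over the last axis.
    z = z - z.max(axis=-1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))


def pred_soft_label_loss(student_logits, teacher_logits, T=1.0):
    # Prediction-logit distillation with the teacher's soft labels only
    # (no ground-truth term). The temperature T is an assumption.
    p_teacher = np.exp(log_softmax(teacher_logits / T))
    return float(-(p_teacher * log_softmax(student_logits / T)).sum(axis=-1).mean())


def train_loss(l_pred, l_attn, l_hddn, beta=1.0, gamma=0.1):
    # L_train = L_pred + beta * L_attn + gamma * L_hddn (defaults from the text).
    return l_pred + beta * l_attn + gamma * l_hddn


rng = np.random.default_rng(0)
logits_t = rng.normal(size=(2, 10))   # batch of 2, 10 classes
logits_s = rng.normal(size=(2, 10))
l_pred = pred_soft_label_loss(logits_s, logits_t)
print(train_loss(l_pred, l_attn=2.0, l_hddn=3.0))   # placeholder attn/hddn values
```

With the default weights, `train_loss(0.5, 2.0, 3.0)` evaluates to about 2.8. The hidden-state term is down-weighted by a factor of 10 relative to the attention term.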

4.3 Compression Pipeline

Stage 1: Generating a compact architecture via weight transformation

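  Given a pre-trained Vision Transformer, the parameters of adjacent Transformer layers are first shared, except for LayerNorm. Weight transformation is then applied to each layer by inserting a linear layer before and after the softmax. In addition, a depth-wise convolution is introduced into the MLP. The parameters of these linear layers and transformation blocks are not shared.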
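A minimal NumPy sketch of stage 1, covering the attention part only. It assumes the linear layers around the softmax work as in talking-heads attention: each one mixes the M per-head attention maps with an M×M matrix. The class name, the near-identity initialisation and the single residual branch are illustrative choices, not the authors' code. The MLP and its depth-wise convolution are omitted.

```python
import numpy as np


def softmax(x, axis=-1):
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def layer_norm(x, g, b, eps=1e-5):
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps) * g + b


class MultiplexedAttention:
    """Hypothetical sketch: `num_layers` adjacent attention layers that store
    Q/K/V/output weights once (shared) but keep per-layer LayerNorm and
    per-layer head-mixing transforms before/after the softmax."""

    def __init__(self, num_layers, dim, heads, seed=0):
        rng = np.random.default_rng(seed)
        self.heads, self.head_dim = heads, dim // heads
        # Shared by all layers (stored once).
        self.w_qkv = rng.normal(0.0, dim ** -0.5, size=(dim, 3 * dim))
        self.w_out = rng.normal(0.0, dim ** -0.5, size=(dim, dim))
        # Per-layer, NOT shared: LayerNorm params and two (heads x heads) maps.
        self.ln = [(np.ones(dim), np.zeros(dim)) for _ in range(num_layers)]
        self.f_pre = [np.eye(heads) + 0.01 * rng.normal(size=(heads, heads))
                      for _ in range(num_layers)]
        self.f_post = [np.eye(heads) + 0.01 * rng.normal(size=(heads, heads))
                       for _ in range(num_layers)]

    def layer(self, x, layer_idx):
        n, dim = x.shape
        y = layer_norm(x, *self.ln[layer_idx])
        q, k, v = np.split(y @ self.w_qkv, 3, axis=-1)
        # (N, dim) -> (heads, N, head_dim)
        q, k, v = (t.reshape(n, self.heads, self.head_dim).transpose(1, 0, 2)
                   for t in (q, k, v))
        logits = q @ k.transpose(0, 2, 1) / np.sqrt(self.head_dim)        # (heads, N, N)
        logits = np.einsum("gh,hnm->gnm", self.f_pre[layer_idx], logits)  # mix heads pre-softmax
        attn = softmax(logits, axis=-1)
        attn = np.einsum("gh,hnm->gnm", self.f_post[layer_idx], attn)     # mix heads post-softmax
        out = (attn @ v).transpose(1, 0, 2).reshape(n, dim)
        return x + out @ self.w_out                                        # residual connection

    def __call__(self, x):
        for layer_idx in range(len(self.ln)):
            x = self.layer(x, layer_idx)
        return x

    def param_counts(self):
        shared = self.w_qkv.size + self.w_out.size
        per_layer = sum(g.size + b.size for g, b in self.ln)
        per_layer += sum(f.size for f in self.f_pre + self.f_post)
        return shared, per_layer


block = MultiplexedAttention(num_layers=4, dim=16, heads=4)
x = np.random.default_rng(1).normal(size=(6, 16))   # 6 tokens
print(block(x).shape, block.param_counts())
```

For four layers with `dim=16`, the shared attention weights are stored once (1024 values) instead of 4 × 1024. The per-layer LayerNorm and head-mixing matrices add only 256 values, yet they make each layer compute something different.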

Stage 2: Training the compressed model with weight distillation

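  In this step, the proposed weight distillation method transfers knowledge from the large pre-trained model to the small one. Distillation inside the Transformer blocks lets the student network reproduce the behavior of the teacher network, so that more useful knowledge is extracted from the large-scale pre-trained model.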

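  Note that this is performed only when both the teacher and the student are Transformers. Otherwise, when the student and teacher architectures are heterogeneous, only prediction-logit distillation is kept.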
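The stage-2 rule can be expressed as a small configuration helper. This is a hypothetical function, not part of any released code. It reuses the default weights β = 1 and γ = 0.1 from the overall loss.

```python
def distillation_weights(teacher_is_transformer, student_is_transformer,
                         beta=1.0, gamma=0.1):
    """Return the loss weights used in stage 2 (hypothetical helper).

    Self-attention and hidden-state distillation need Transformer internals
    on both sides, so they are enabled only when both models are
    Transformers; otherwise only prediction-logit distillation is kept.
    """
    weights = {"pred": 1.0}
    if teacher_is_transformer and student_is_transformer:
        weights.update(attn=beta, hddn=gamma)
    return weights


print(distillation_weights(True, True))    # Transformer teacher and student
print(distillation_weights(False, True))   # e.g. a CNN teacher: logits only
```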

5. Experiments

5.1 Ablation Studies

5.1.1 Weight Sharing

[Figure 4: gradient ℓ2-norms of DeiT-S and Swin-B during training]
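  As shown in Figure 4, after weight sharing the ℓ2-norm of the gradients of DeiT-S and Swin-B is large, indicating that the weight magnitudes change quickly. Weight sharing also makes the gradient norm fluctuate across layers, so different layers may be optimized at very different rates. In particular, some layers are updated quickly while others are barely optimized. This makes the model likely to converge to a poor local optimum, or even to diverge during training.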

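  Therefore, strictly identical weights shared across layers cause training instability. The weight multiplexing method, in contrast, increases parameter diversity by introducing transformations. This lowers the gradient norm and makes it smoother across layers, leading to a more stable training process.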
[Figure 5: CKA feature similarity, weight sharing vs. weight multiplexing]
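  For the performance analysis, Figure 5 uses CKA to compare feature similarity under weight sharing and under weight multiplexing. Similarity is measured between each compressed model's features and those of the original model. A higher CKA value means the two models produce more similar feature representations and therefore tend to perform similarly. Both DeiT and Swin show a large deviation in their feature representations after weight sharing, especially in the last few layers. This may be one reason for the performance drop caused by weight sharing. The proposed weight multiplexing method increases this similarity.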
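For reference, a common CKA variant is linear CKA, sketched below in NumPy. The text does not say which CKA variant the figure uses, so treating it as linear CKA is an assumption. Its inputs are the features of two models on the same `n` inputs, e.g. one layer of the original model and the corresponding layer of the compressed model. The function name is illustrative.

```python
import numpy as np


def linear_cka(x, y):
    """Linear CKA between features x (n, p1) and y (n, p2) of the same n inputs."""
    x = x - x.mean(axis=0, keepdims=True)   # center every feature dimension
    y = y - y.mean(axis=0, keepdims=True)
    cross = np.linalg.norm(y.T @ x, "fro") ** 2
    norm_x = np.linalg.norm(x.T @ x, "fro")
    norm_y = np.linalg.norm(y.T @ y, "fro")
    return float(cross / (norm_x * norm_y))


rng = np.random.default_rng(0)
feats = rng.normal(size=(100, 32))                    # features from one model/layer
q, _ = np.linalg.qr(rng.normal(size=(32, 32)))        # random orthogonal matrix
print(linear_cka(feats, 3.0 * feats @ q))             # rotated + rescaled copy: ~1.0
print(linear_cka(feats, rng.normal(size=(100, 16))))  # unrelated features: lower
```

Linear CKA is invariant to rotations and isotropic scaling of the features. It therefore measures whether two layers encode the same information, not whether their weights or feature coordinates match.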

5.1.2 Component-wise

[Table 1: component-wise ablation]

5.1.3 Number of Sharing Blocks

[Table 2: effect of the number of shared blocks]
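  As Table 2 shows, sharing 2 blocks in each stage of Swin-T or DeiT-B reduces the parameters substantially, from 28M to 16M and from 86M to 44M respectively, while improving Top-1 accuracy by 1%. In the extreme case where all blocks in each stage are shared, Mini-Swin-T still outperforms the original model with only 43% of its parameters. Mini-DeiT-B achieves a 90% parameter reduction with only a 2% drop in accuracy.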
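The parameter savings quoted above can be checked with a few lines of arithmetic. The counts are the rounded values from the text; the 86M → 9M case comes from the abstract. The helper name is illustrative.

```python
def reduction(before_m, after_m):
    # Percentage of parameters removed when going from before_m to after_m (millions).
    return 100.0 * (before_m - after_m) / before_m


cases = {
    "Swin-T, 2 shared blocks per stage": (28, 16),
    "DeiT-B, 2 shared blocks": (86, 44),
    "DeiT-B, single-layer weights (abstract)": (86, 9),
}
for name, (before, after) in cases.items():
    print(f"{name}: {before}M -> {after}M, "
          f"{reduction(before, after):.1f}% fewer, {before / after:.2f}x smaller")
```

Sharing 2 blocks removes roughly 43% and 49% of the parameters. 86M → 9M is about an 89.5% reduction, consistent with the ~90% quoted above. The rounded 9M gives about 9.6x, so the abstract's 9.7x is presumably computed from unrounded counts.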

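  These results show that Vision Transformers contain redundant parameters and that the proposed MiniViT effectively improves parameter efficiency. Moreover, MiniViT is configurable to meet different requirements on model size and performance.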

5.1.4 Distillation Losses

[Table 3: ablation of the distillation losses]
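  As shown in Table 3, adding ground-truth labels lowers the performance of Swin by 0.3% compared with using only the prediction loss. The authors attribute this to the reduced learning capacity caused by weight sharing. Applying self-attention distillation and hidden-state distillation improves accuracy by about 0.2%, which demonstrates the effectiveness of the proposed distillation.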

5.2 SoTA Classification

[Table: comparison with state-of-the-art models on ImageNet classification]

5.3 Transfer Learning

[Table: transfer-learning results on downstream benchmarks]
