【nn.Parameter】Setting adaptive fusion weights in PyTorch (using learnable weights)

2021-11-17 11:32:14
Today we implement adaptive, learnable weight coefficients in PyTorch: when fusing features, each feature map is assigned its own learnable weight.


Paper: 基于自适应特征融合与转换的小样本图像分类 (Few-Shot Image Classification Based on Adaptive Feature Fusion and Transformation, 2021)
Journal: Computer Engineering and Applications (计算机工程与应用; Chinese core journal, CSCD extended edition)

We will implement the multi-feature-fusion branch from this paper.


The adaptive feature-processing module is shown in the paper's figure (image omitted here).
The feature-fusion formula is:

$$
F_{ff} = \alpha_1 \cdot F_{id} + \alpha_2 \cdot F_{dconv} + \alpha_3 \cdot F_{\max} + \alpha_4 \cdot F_{avg}
$$

$$
\alpha_i = \frac{e^{w_i}}{\sum_{j} e^{w_j}} \qquad (i = 1, 2, 3, 4;\; j = 1, 2, 3, 4)
$$

That is, the four branch feature maps are combined with weights $\alpha_i$, which are the softmax normalization of learnable parameters $w_i$.
0.0037em;">α</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.301108em;"><span class="" style="top: -2.55em; margin-left: -0.0037em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight">1</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height: 0.15em;"><span class=""></span></span></span></span></span></span><span class="mspace" style="margin-right: 0.222222em;"></span><span class="mbin">∗</span><span class="mspace" style="margin-right: 0.222222em;"></span><span class="mord"><span class="mord mathdefault" style="margin-right: 0.13889em;">F</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.336108em;"><span class="" style="top: -2.55em; margin-left: -0.13889em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathdefault mtight">i</span><span class="mord mathdefault mtight">d</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height: 0.15em;"><span class=""></span></span></span></span></span></span><span class="mspace" style="margin-right: 0.222222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right: 0.222222em;"></span><span class="mord"><span class="mord mathdefault" style="margin-right: 0.0037em;">α</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.301108em;"><span class="" style="top: -2.55em; margin-left: -0.0037em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span 
class="mord mtight">2</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height: 0.15em;"><span class=""></span></span></span></span></span></span><span class="mspace" style="margin-right: 0.222222em;"></span><span class="mbin">∗</span><span class="mspace" style="margin-right: 0.222222em;"></span><span class="mord"><span class="mord mathdefault" style="margin-right: 0.13889em;">F</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.336108em;"><span class="" style="top: -2.55em; margin-left: -0.13889em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathdefault mtight">d</span><span class="mord mathdefault mtight">c</span><span class="mord mathdefault mtight">o</span><span class="mord mathdefault mtight">n</span><span class="mord mathdefault mtight" style="margin-right: 0.03588em;">v</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height: 0.15em;"><span class=""></span></span></span></span></span></span><span class="mspace" style="margin-right: 0.222222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right: 0.222222em;"></span><span class="mord"><span class="mord mathdefault" style="margin-right: 0.0037em;">α</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.301108em;"><span class="" style="top: -2.55em; margin-left: -0.0037em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight">3</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height: 0.15em;"><span 
class=""></span></span></span></span></span></span><span class="mspace" style="margin-right: 0.222222em;"></span><span class="mbin">∗</span><span class="mspace" style="margin-right: 0.222222em;"></span><span class="mord"><span class="mord mathdefault" style="margin-right: 0.13889em;">F</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.151392em;"><span class="" style="top: -2.15139em; margin-left: -0.13889em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.30139em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mop mtight">max</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height: 0.15em;"><span class=""></span></span></span></span></span></span><span class="mspace" style="margin-right: 0.222222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right: 0.222222em;"></span><span class="mord"><span class="mord mathdefault" style="margin-right: 0.0037em;">α</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.301108em;"><span class="" style="top: -2.55em; margin-left: -0.0037em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight">4</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height: 0.15em;"><span class=""></span></span></span></span></span></span><span class="mspace" style="margin-right: 0.222222em;"></span><span class="mbin">∗</span><span class="mspace" style="margin-right: 0.222222em;"></span><span class="mord"><span class="mord mathdefault" style="margin-right: 0.13889em;">F</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 
0.151392em;"><span class="" style="top: -2.55em; margin-left: -0.13889em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathdefault mtight">a</span><span class="mord mathdefault mtight" style="margin-right: 0.03588em;">v</span><span class="mord mathdefault mtight" style="margin-right: 0.03588em;">g</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height: 0.286108em;"><span class=""></span></span></span></span></span></span></span></span><span class="" style="top: -2.80675em;"><span class="pstrut" style="height: 3.34139em;"></span><span class="mord"><span class="mord"></span><span class="mord"><span class="mord mathdefault">a</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.311664em;"><span class="" style="top: -2.55em; margin-left: 0em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathdefault mtight">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height: 0.15em;"><span class=""></span></span></span></span></span></span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 1.34139em;"><span class="" style="top: -2.314em;"><span class="pstrut" style="height: 3em;"></span><span class="mord"><span class="mord"><span class="mord">Σ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.311664em;"><span 
class="" style="top: -2.55em; margin-left: 0em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathdefault mtight" style="margin-right: 0.05724em;">j</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height: 0.286108em;"><span class=""></span></span></span></span></span></span><span class="mord"><span class="mord mathdefault">e</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height: 0.606462em;"><span class="" style="top: -3.00507em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight"><span class="mord mathdefault mtight" style="margin-right: 0.02691em;">w</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.328086em;"><span class="" style="top: -2.357em; margin-left: -0.02691em; margin-right: 0.0714286em;"><span class="pstrut" style="height: 2.5em;"></span><span class="sizing reset-size3 size1 mtight"><span class="mord mtight"><span class="mord mathdefault mtight" style="margin-right: 0.05724em;">j</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height: 0.281886em;"><span class=""></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span><span class="" style="top: -3.23em;"><span class="pstrut" style="height: 3em;"></span><span class="frac-line" style="border-bottom-width: 0.04em;"></span></span><span class="" style="top: -3.677em;"><span class="pstrut" style="height: 3em;"></span><span class="mord"><span class="mord"><span class="mord mathdefault">e</span><span class="msupsub"><span class="vlist-t"><span 
class="vlist-r"><span class="vlist" style="height: 0.664392em;"><span class="" style="top: -3.063em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight"><span class="mord mathdefault mtight" style="margin-right: 0.02691em;">w</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.328086em;"><span class="" style="top: -2.357em; margin-left: -0.02691em; margin-right: 0.0714286em;"><span class="pstrut" style="height: 2.5em;"></span><span class="sizing reset-size3 size1 mtight"><span class="mord mtight"><span class="mord mathdefault mtight">i</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height: 0.143em;"><span class=""></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height: 0.972108em;"><span class=""></span></span></span></span></span><span class="mclose nulldelimiter"></span></span><span class="mopen">(</span><span class="mord mathdefault">i</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mord">1</span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord">2</span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord">3</span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord">4</span><span class="mpunct">;</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathdefault" style="margin-right: 0.05724em;">j</span><span class="mspace" 
style="margin-right: 0.277778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mord">1</span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord">2</span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord">3</span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord">4</span><span class="mclose">)</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height: 1.80675em;"><span class=""></span></span></span></span></span></span></span></span></span></span></span></span><br> 其中,<span class="katex--inline"><span class="katex"><span class="katex-mathml">

 
  
   
    
     α
    
    
     i
    
   
  
  
   \alpha_i
  
 
</span><span class="katex-html"><span class="base"><span class="strut" style="height: 0.58056em; vertical-align: -0.15em;"></span><span class="mord"><span class="mord mathdefault" style="margin-right: 0.0037em;">α</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.311664em;"><span class="" style="top: -2.55em; margin-left: -0.0037em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathdefault mtight">i</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height: 0.15em;"><span class=""></span></span></span></span></span></span></span></span></span></span>为归一化权重,<span class="katex--inline"><span class="katex"><span class="katex-mathml">

 
  
   
    Σ
   
   
    
     α
    
    
     i
    
   
   
    =
   
   
    1
   
  
  
   \Sigma\alpha_i=1
  
 
</span><span class="katex-html"><span class="base"><span class="strut" style="height: 0.83333em; vertical-align: -0.15em;"></span><span class="mord">Σ</span><span class="mord"><span class="mord mathdefault" style="margin-right: 0.0037em;">α</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.311664em;"><span class="" style="top: -2.55em; margin-left: -0.0037em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathdefault mtight">i</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height: 0.15em;"><span class=""></span></span></span></span></span></span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right: 0.277778em;"></span></span><span class="base"><span class="strut" style="height: 0.64444em; vertical-align: 0em;"></span><span class="mord">1</span></span></span></span></span>,<span class="katex--inline"><span class="katex"><span class="katex-mathml">

 
  
   
    
     w
    
    
     i
    
   
  
  
   w_i
  
 
</span><span class="katex-html"><span class="base"><span class="strut" style="height: 0.58056em; vertical-align: -0.15em;"></span><span class="mord"><span class="mord mathdefault" style="margin-right: 0.02691em;">w</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.311664em;"><span class="" style="top: -2.55em; margin-left: -0.02691em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathdefault mtight">i</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height: 0.15em;"><span class=""></span></span></span></span></span></span></span></span></span></span>为初始化权重系数。</p> 
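As a quick sanity check of the normalization formula, here is a plain-Python sketch (independent of PyTorch, names are my own): with all initial weights $w_i = 1$, the softmax assigns each branch the same weight of about 0.25, and the weights always sum to 1.

```python
import math

def normalize(w):
    """Softmax normalization: alpha_i = e^{w_i} / sum_j e^{w_j}."""
    total = sum(math.exp(v) for v in w)
    return [math.exp(v) / total for v in w]

# Equal initial weights -> equal fusion coefficients
alphas = normalize([1.0, 1.0, 1.0, 1.0])
assert all(abs(a - 0.25) < 1e-9 for a in alphas)

# Regardless of the raw weights, the normalized coefficients sum to 1
assert abs(sum(normalize([2.0, 0.5, -1.0, 0.3])) - 1.0) < 1e-9
```

This is exactly what the `w1..w4` computation in the module's `forward` does with `torch.exp` and `torch.sum`.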

Structure analysis:

  1. A single input feature map is fed to four branches.
  2. From top to bottom, the first branch applies Maxpooling to extract local features.
  3. The second branch applies Avgpooling to extract global features.
  4. For the third branch, the paper states: "two sets of 1×1 convolutions halve the feature channels, firstly to reduce the parameter count and guard against overfitting, and secondly to make the subsequent channel-wise concatenation and additive fusion of the convolutional features convenient; after the first 1×1 convolution, two 3×3 convolutions are stacked to replace a 5×5 convolution, and the results are then concatenated along the channel dimension (Combine)." The paper calls this the double-convolution branch DConv; convolution extracts rich features, and an SE attention module follows the concatenation.
  5. The fourth branch is the identity (residual) branch Identity: the input is passed along through a skip connection unchanged, preserving the original features.
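The channel arithmetic of the DConv branch in step 4 can be verified directly: with a 3-channel input, the two 1×1 compressions produce 1 and 2 channels respectively, and channel-wise concatenation restores the original 3 channels. A minimal sketch (layer widths match the implementation later in this post):

```python
import torch
import torch.nn as nn

x = torch.randn(10, 3, 84, 84)               # batch of 3-channel 84x84 inputs

path1 = nn.Conv2d(3, 1, 1)(x)                # 1x1 conv compresses 3 -> 1 channel
path2 = nn.Conv2d(3, 2, 1)(x)                # 1x1 conv compresses 3 -> 2 channels
combined = torch.cat((path1, path2), dim=1)  # Combine: concatenate along channels

assert combined.shape == x.shape             # back to (10, 3, 84, 84)
```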

Model analysis:

  1. Looking at the module structure: since the final fusion operation is Add, the four branches must output feature maps with identical spatial size and channel count. The skip connection leaves the input unchanged, so we keep all four branch outputs the same shape as the input.
  2. In Section 3.2.2 (parameter settings) the paper sets the max-pooling kernel size to 3 and the average-pooling kernel size to 2.
  3. All feature weights are initialized to 1, implemented with nn.Parameter.
  4. The input image size is 3×84×84.
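Before the full module, a minimal sketch of what nn.Parameter gives us (the toy `Fuse` class below is a hypothetical two-branch example, not the AFP module): a tensor wrapped in nn.Parameter and assigned as a module attribute is automatically registered in `parameters()`, so the optimizer updates it and gradients flow into it.

```python
import torch
import torch.nn as nn

class Fuse(nn.Module):  # toy two-branch fusion, for illustration only
    def __init__(self):
        super(Fuse, self).__init__()
        self.w = nn.Parameter(torch.ones(2))  # registered as a learnable parameter

    def forward(self, a, b):
        alpha = torch.softmax(self.w, dim=0)  # normalize so the weights sum to 1
        return alpha[0] * a + alpha[1] * b

model = Fuse()
print(dict(model.named_parameters()))  # 'w' shows up automatically

out = model(torch.ones(3), torch.zeros(3))
out.sum().backward()
print(model.w.grad)  # gradients reach the fusion weights
```

The AFP module below follows the same pattern with four weights, writing the softmax out by hand with `torch.exp` / `torch.sum`.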

PyTorch implementation of the AFP module

"""
Author: yida
Time is: 2021/11/17 15:45
this Code:
1. Implements the adaptive feature processing module (AFP) from 《基于自适应特征融合与转换的小样本图像分类》
2. Demonstrates the use of nn.Parameter
"""
import torch
import torch.nn as nn

class AFP(nn.Module):
    def __init__(self):
        super(AFP, self).__init__()

        self.branch1 = nn.Sequential(
            nn.MaxPool2d(3, 1, padding=1),  # 1. Max-pooling branch: the paper sets the kernel size to 3 but gives no stride/padding; (3, 1, 1) keeps the output the same size as the input
        )
        self.branch2 = nn.Sequential(
            nn.AvgPool2d(3, 1, padding=1),  # 2. Average-pooling branch: the paper sets the kernel size to 2 but gives no stride/padding; (3, 1, 1) keeps the output the same size as the input
        )

        self.branch3_1 = nn.Sequential(
            nn.Conv2d(3, 1, 1),
            nn.Conv2d(1, 1, 3, padding=1),  # Branch 3_1: a 1x1 conv compresses the channels, then two 3x3 convs extract features; since 3 // 2 == 1, the output channel count here is 1
            nn.Conv2d(1, 1, 3, padding=1),
        )

        self.branch3_2 = nn.Sequential(
            nn.Conv2d(3, 2, 1),  # Branch 3_2: the 1x1 conv should halve the channels, but the input has 3 and branch 3_1 already took 1, so the output channel count here is 2
            nn.Conv2d(2, 2, 3, padding=1)
        )
        # Attention module
        self.branch_SE = SEblock(channel=3)

        # Initialize the learnable weight coefficients
        # A weight created with nn.Parameter is added to the parameters the optimizer updates, so each step adjusts it in the direction that minimizes the loss

        self.w = nn.Parameter(torch.ones(4))  # One adaptive weight per branch, initialized to 1; nn.Parameter expects a Tensor
        # self.w = nn.Parameter(torch.Tensor([0.5, 0.25, 0.15, 0.1]), requires_grad=False)  # Fixed weight coefficients: no normalization, multiplied in directly

    def forward(self, x):
        b1 = self.branch1(x)
        b2 = self.branch2(x)

        b3_1 = self.branch3_1(x)
        b3_2 = self.branch3_2(x)
        b3_Combine = torch.cat((b3_1, b3_2), dim=1)
        b3 = self.branch_SE(b3_Combine)

        b4 = x

        print("b1:", b1.shape)
        print("b2:", b2.shape)
        print("b3:", b3.shape)
        print("b4:", b4.shape)

        # Normalize the weights (softmax over self.w)
        w1 = torch.exp(self.w[0]) / torch.sum(torch.exp(self.w))
        w2 = torch.exp(self.w[1]) / torch.sum(torch.exp(self.w))
        w3 = torch.exp(self.w[2]) / torch.sum(torch.exp(self.w))
        w4 = torch.exp(self.w[3]) / torch.sum(torch.exp(self.w))

        # Multi-feature fusion
        x_out = b1 * w1 + b2 * w2 + b3 * w3 + b4 * w4
        print("Fused feature shape:", x_out.shape)
        return x_out

class SEblock(nn.Module):  # Squeeze-and-Excitation attention module
    def __init__(self, channel, r=0.5):  # channel: input channel count; r: reduction ratio of the FC layers -> controls the hidden size
        super(SEblock, self).__init__()
        # Global average pooling
        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
        # Fully connected layers
        self.fc = nn.Sequential(
            nn.Linear(channel, int(channel * r)),  # int(channel * r) rounds down to an integer
            nn.ReLU(),
            nn.Linear(int(channel * r), channel),
            nn.Sigmoid(),
        )

    def forward(self, x):
        # Compute the channel weights: global average pooling over x
        branch = self.global_avg_pool(x)
        branch = branch.view(branch.size(0), -1)

        # Fully connected layers produce the weights
        weight = self.fc(branch)

        # Reshape weight from (b, c) to (b, c, 1, 1) so it can be multiplied with the input x
        h, w = weight.shape
        weight = torch.reshape(weight, (h, w, 1, 1))

        # Scale the input by the weights
        scale = weight * x
        return scale

if name == main:
model = AFP()
print(model)
inputs = torch.randn(10, 3, 84, 84)
print("输入维度为: ", inputs.shape)
outputs = model(inputs)
print("输出维度为: ", outputs.shape)

<span class="token comment"># 查看nn.Parameter中值的变化, 训练网络时, 更新优化器之后, 可以循环输出, 查看权重变化</span>
<span class="token keyword">for</span> name<span class="token punctuation">,</span> p <span class="token keyword">in</span> model<span class="token punctuation">.</span>named_parameters<span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token punctuation">:</span>
    <span class="token keyword">if</span> name <span class="token operator">==</span> <span class="token string">'w'</span><span class="token punctuation">:</span>
        <span class="token keyword">print</span><span class="token punctuation">(</span><span class="token string">"特征权重: "</span><span class="token punctuation">,</span> name<span class="token punctuation">)</span>
        w0 <span class="token operator">=</span> <span class="token punctuation">(</span>torch<span class="token punctuation">.</span>exp<span class="token punctuation">(</span>p<span class="token punctuation">[</span><span class="token number">0</span><span class="token punctuation">]</span><span class="token punctuation">)</span> <span class="token operator">/</span> torch<span class="token punctuation">.</span><span class="token builtin">sum</span><span class="token punctuation">(</span>torch<span class="token punctuation">.</span>exp<span class="token punctuation">(</span>p<span class="token punctuation">)</span><span class="token punctuation">)</span><span class="token punctuation">)</span><span class="token punctuation">.</span>item<span class="token punctuation">(</span><span class="token punctuation">)</span>
        w1 <span class="token operator">=</span> <span class="token punctuation">(</span>torch<span class="token punctuation">.</span>exp<span class="token punctuation">(</span>p<span class="token punctuation">[</span><span class="token number">1</span><span class="token punctuation">]</span><span class="token punctuation">)</span> <span class="token operator">/</span> torch<span class="token punctuation">.</span><span class="token builtin">sum</span><span class="token punctuation">(</span>torch<span class="token punctuation">.</span>exp<span class="token punctuation">(</span>p<span class="token punctuation">)</span><span class="token punctuation">)</span><span class="token punctuation">)</span><span class="token punctuation">.</span>item<span class="token punctuation">(</span><span class="token punctuation">)</span>
        w2 <span class="token operator">=</span> <span class="token punctuation">(</span>torch<span class="token punctuation">.</span>exp<span class="token punctuation">(</span>p<span class="token punctuation">[</span><span class="token number">2</span><span class="token punctuation">]</span><span class="token punctuation">)</span> <span class="token operator">/</span> torch<span class="token punctuation">.</span><span class="token builtin">sum</span><span class="token punctuation">(</span>torch<span class="token punctuation">.</span>exp<span class="token punctuation">(</span>p<span class="token punctuation">)</span><span class="token punctuation">)</span><span class="token punctuation">)</span><span class="token punctuation">.</span>item<span class="token punctuation">(</span><span class="token punctuation">)</span>
        w3 <span class="token operator">=</span> <span class="token punctuation">(</span>torch<span class="token punctuation">.</span>exp<span class="token punctuation">(</span>p<span class="token punctuation">[</span><span class="token number">3</span><span class="token punctuation">]</span><span class="token punctuation">)</span> <span class="token operator">/</span> torch<span class="token punctuation">.</span><span class="token builtin">sum</span><span class="token punctuation">(</span>torch<span class="token punctuation">.</span>exp<span class="token punctuation">(</span>p<span class="token punctuation">)</span><span class="token punctuation">)</span><span class="token punctuation">)</span><span class="token punctuation">.</span>item<span class="token punctuation">(</span><span class="token punctuation">)</span>
        <span class="token keyword">print</span><span class="token punctuation">(</span>w0<span class="token punctuation">,</span> w1<span class="token punctuation">,</span> w2<span class="token punctuation">,</span> w3<span class="token punctuation">)</span>
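To confirm that weights normalized this way actually receive gradients, here is a quick standalone check (the exp normalization mirrors the code above; the tensor shapes are illustrative):

```python
import torch

w = torch.nn.Parameter(torch.ones(4))
feats = [torch.randn(2, 3, 8, 8) for _ in range(4)]

a = torch.exp(w) / torch.sum(torch.exp(w))  # normalized weights sum to 1
out = sum(a[i] * f for i, f in enumerate(feats))
out.mean().backward()

print(a.detach())      # tensor([0.2500, 0.2500, 0.2500, 0.2500])
print(w.grad is None)  # False: w is trainable
```

Because `w` starts at all ones, every branch initially gets an equal weight of 0.25; training then shifts the balance.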

nn.Parameter: the weight coefficients w_i in the feature-fusion formula above.

Using nn.Parameter: setting up learnable weights.

(figure: nn.Parameter learnable-weight setup)
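The pattern boils down to registering a weight vector with nn.Parameter and normalizing it in forward. A minimal self-contained sketch (module and tensor names are illustrative):

```python
import torch
import torch.nn as nn

class WeightedFusion(nn.Module):
    def __init__(self, n_branches=4):
        super().__init__()
        # registered as a Parameter, so the optimizer updates it with the rest of the model
        self.w = nn.Parameter(torch.ones(n_branches))

    def forward(self, feats):
        a = torch.softmax(self.w, dim=0)  # same as exp(w_i) / sum_j exp(w_j)
        return sum(a[i] * f for i, f in enumerate(feats))

fusion = WeightedFusion(4)
feats = [torch.randn(2, 64, 21, 21) for _ in range(4)]
out = fusion(feats)
print(out.shape)   # torch.Size([2, 64, 21, 21])
print(fusion.w)    # appears in fusion.parameters(), so it is trained
```

Plain tensors assigned to `self` would not be returned by `parameters()` and would never be updated; wrapping them in nn.Parameter is what makes them learnable.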


Update log

  • 2022-04-14 17:37:53

Quite a few readers have been following this post recently, so I took a handwritten-digit-recognition example that runs as-is, put the learnable parameters into it, and print them during training, to give everyone a clearer picture of how learnable parameters change.

MNIST handwritten digit recognition code

"""
Author: yida
Time is: 2022/3/6 09:30
this Code: 代码原文: https://www.cnblogs.com/wj-1314/p/9842719.html
- 代码: 手写数字识别, 源码参考上面的链接, 仅仅包含两个卷积层的手写数字识别 对每个卷积层设置一个权重系数w
- 可直接运行, torchvision会自动下载手写数字识别的数据集 存放在当前文件夹 ./mnist 模型保存为./model.pth
- 未实现测试功能 大家可以自行添加
- 为了便于大家更好的理解可学习参数
- 直接放到代码里面, 边训练边输出, 方便各位理解
"""
import os

import torch
import torch.nn as nn
import torchvision.datasets as normal_datasets
import torchvision.transforms as transforms
from torch.autograd import Variable

os.environ[“KMP_DUPLICATE_LIB_OK”] = “TRUE”

# 两层卷积
class CNN(nn.Module):
def init(self):
super(CNN, self).init()
# 使用序列工具快速构建
self.conv1 = nn.Sequential(
nn.Conv2d(1, 16, kernel_size=5, padding=2),
nn.BatchNorm2d(16),
nn.ReLU(),
nn.MaxPool2d(2))

    self.conv2 <span class="token operator">=</span> nn.Sequential<span class="token punctuation">(</span>
        nn.Conv2d<span class="token punctuation">(</span><span class="token number">16</span>, <span class="token number">32</span>, <span class="token assign-left variable">kernel_size</span><span class="token operator">=</span><span class="token number">5</span>, <span class="token assign-left variable">padding</span><span class="token operator">=</span><span class="token number">2</span><span class="token punctuation">)</span>,
        nn.BatchNorm2d<span class="token punctuation">(</span><span class="token number">32</span><span class="token punctuation">)</span>,
        nn.ReLU<span class="token punctuation">(</span><span class="token punctuation">)</span>,
        nn.MaxPool2d<span class="token punctuation">(</span><span class="token number">2</span><span class="token punctuation">))</span>

    self.fc <span class="token operator">=</span> nn.Linear<span class="token punctuation">(</span><span class="token number">7</span> * <span class="token number">7</span> * <span class="token number">32</span>, <span class="token number">10</span><span class="token punctuation">)</span>
    self.w <span class="token operator">=</span> nn.Parameter<span class="token punctuation">(</span>torch.ones<span class="token punctuation">(</span><span class="token number">2</span><span class="token punctuation">))</span>  <span class="token comment"># 初始化权重, 对2个卷积分别加一个权重</span>

def forward<span class="token punctuation">(</span>self, x<span class="token punctuation">)</span>:
    <span class="token comment"># 归一化权重</span>
    w1 <span class="token operator">=</span> torch.exp<span class="token punctuation">(</span>self.w<span class="token punctuation">[</span><span class="token number">0</span><span class="token punctuation">]</span><span class="token punctuation">)</span> / torch.sum<span class="token punctuation">(</span>torch.exp<span class="token punctuation">(</span>self.w<span class="token punctuation">))</span>
    w2 <span class="token operator">=</span> torch.exp<span class="token punctuation">(</span>self.w<span class="token punctuation">[</span><span class="token number">1</span><span class="token punctuation">]</span><span class="token punctuation">)</span> / torch.sum<span class="token punctuation">(</span>torch.exp<span class="token punctuation">(</span>self.w<span class="token punctuation">))</span>
    out <span class="token operator">=</span> self.conv1<span class="token punctuation">(</span>x<span class="token punctuation">)</span> * w1
    out <span class="token operator">=</span> self.conv2<span class="token punctuation">(</span>out<span class="token punctuation">)</span> * w2
    out <span class="token operator">=</span> out.view<span class="token punctuation">(</span>out.size<span class="token punctuation">(</span><span class="token number">0</span><span class="token punctuation">)</span>, -1<span class="token punctuation">)</span>  <span class="token comment"># reshape</span>
    out <span class="token operator">=</span> self.fc<span class="token punctuation">(</span>out<span class="token punctuation">)</span>
    <span class="token builtin class-name">return</span> out

# 将数据处理成Variable, 如果有GPU, 可以转成cuda形式
def get_variable(x):
x = Variable(x)
return x.cuda() if torch.cuda.is_available() else x

if name == main:

num_epochs <span class="token operator">=</span> <span class="token number">5</span>
batch_size <span class="token operator">=</span> <span class="token number">100</span>
learning_rate <span class="token operator">=</span> <span class="token number">0.001</span>

<span class="token comment"># 从torchvision.datasets中加载一些常用数据集</span>
train_dataset <span class="token operator">=</span> normal_datasets.MNIST<span class="token punctuation">(</span>
    <span class="token assign-left variable">root</span><span class="token operator">=</span><span class="token string">'./mnist/'</span>,  <span class="token comment"># 数据集保存路径</span>
    <span class="token assign-left variable">train</span><span class="token operator">=</span>True,  <span class="token comment"># 是否作为训练集</span>
    <span class="token assign-left variable">transform</span><span class="token operator">=</span>transforms.ToTensor<span class="token punctuation">(</span><span class="token punctuation">)</span>,  <span class="token comment"># 数据如何处理, 可以自己自定义</span>
    <span class="token assign-left variable">download</span><span class="token operator">=</span>True<span class="token punctuation">)</span>  <span class="token comment"># 路径下没有的话, 可以下载</span>

<span class="token comment"># 见数据加载器和batch</span>
test_dataset <span class="token operator">=</span> normal_datasets.MNIST<span class="token punctuation">(</span>root<span class="token operator">=</span><span class="token string">'./mnist/'</span>,
                                     <span class="token assign-left variable">train</span><span class="token operator">=</span>False,
                                     <span class="token assign-left variable">transform</span><span class="token operator">=</span>transforms.ToTensor<span class="token punctuation">(</span><span class="token punctuation">))</span>

train_loader <span class="token operator">=</span> torch.utils.data.DataLoader<span class="token punctuation">(</span>dataset<span class="token operator">=</span>train_dataset,
                                           <span class="token assign-left variable">batch_size</span><span class="token operator">=</span>batch_size,
                                           <span class="token assign-left variable">shuffle</span><span class="token operator">=</span>True<span class="token punctuation">)</span>

test_loader <span class="token operator">=</span> torch.utils.data.DataLoader<span class="token punctuation">(</span>dataset<span class="token operator">=</span>test_dataset,
                                          <span class="token assign-left variable">batch_size</span><span class="token operator">=</span>batch_size,
                                          <span class="token assign-left variable">shuffle</span><span class="token operator">=</span>False<span class="token punctuation">)</span>

model <span class="token operator">=</span> CNN<span class="token punctuation">(</span><span class="token punctuation">)</span>
<span class="token keyword">if</span> torch.cuda.is_available<span class="token punctuation">(</span><span class="token punctuation">)</span>:
    model <span class="token operator">=</span> model.cuda<span class="token punctuation">(</span><span class="token punctuation">)</span>

<span class="token comment"># 选择损失函数和优化方法</span>
loss_func <span class="token operator">=</span> nn.CrossEntropyLoss<span class="token punctuation">(</span><span class="token punctuation">)</span>
optimizer <span class="token operator">=</span> torch.optim.Adam<span class="token punctuation">(</span>model.parameters<span class="token punctuation">(</span><span class="token punctuation">)</span>, <span class="token assign-left variable">lr</span><span class="token operator">=</span>learning_rate<span class="token punctuation">)</span>

<span class="token keyword">for</span> <span class="token for-or-select variable">epoch</span> <span class="token keyword">in</span> range<span class="token punctuation">(</span>num_epochs<span class="token punctuation">)</span>:
    <span class="token keyword">for</span> i, <span class="token punctuation">(</span>images, labels<span class="token punctuation">)</span> <span class="token keyword">in</span> enumerate<span class="token punctuation">(</span>train_loader<span class="token punctuation">)</span>:
        images <span class="token operator">=</span> get_variable<span class="token punctuation">(</span>images<span class="token punctuation">)</span>
        labels <span class="token operator">=</span> get_variable<span class="token punctuation">(</span>labels<span class="token punctuation">)</span>
        outputs <span class="token operator">=</span> model<span class="token punctuation">(</span>images<span class="token punctuation">)</span>
        loss <span class="token operator">=</span> loss_func<span class="token punctuation">(</span>outputs, labels<span class="token punctuation">)</span>
        optimizer.zero_grad<span class="token punctuation">(</span><span class="token punctuation">)</span>
        loss.backward<span class="token punctuation">(</span><span class="token punctuation">)</span>
        optimizer.step<span class="token punctuation">(</span><span class="token punctuation">)</span>
        <span class="token keyword">if</span> <span class="token punctuation">(</span>i + <span class="token number">1</span><span class="token punctuation">)</span> % <span class="token number">100</span> <span class="token operator">==</span> <span class="token number">0</span>:
            print<span class="token punctuation">(</span><span class="token string">'Epoch [%d/%d], Iter [%d/%d] Loss: %.4f'</span>
                  % <span class="token punctuation">(</span>epoch + <span class="token number">1</span>, num_epochs, i + <span class="token number">1</span>, len<span class="token punctuation">(</span>train_dataset<span class="token punctuation">)</span> // batch_size, loss.item<span class="token punctuation">(</span><span class="token punctuation">))</span><span class="token punctuation">)</span>

            <span class="token comment"># 动态输出w权重变换</span>
            <span class="token keyword">for</span> name, p <span class="token keyword">in</span> model.named_parameters<span class="token punctuation">(</span><span class="token punctuation">)</span>:
                <span class="token keyword">if</span> name <span class="token operator">==</span> <span class="token string">'w'</span><span class="token builtin class-name">:</span>
                    print<span class="token punctuation">(</span><span class="token string">"特征权重: "</span>, name<span class="token punctuation">)</span>
                    w0 <span class="token operator">=</span> <span class="token punctuation">(</span>torch.exp<span class="token punctuation">(</span>p<span class="token punctuation">[</span><span class="token number">0</span><span class="token punctuation">]</span><span class="token punctuation">)</span> / torch.sum<span class="token punctuation">(</span>torch.exp<span class="token punctuation">(</span>p<span class="token punctuation">))</span><span class="token punctuation">)</span>.item<span class="token punctuation">(</span><span class="token punctuation">)</span>
                    w1 <span class="token operator">=</span> <span class="token punctuation">(</span>torch.exp<span class="token punctuation">(</span>p<span class="token punctuation">[</span><span class="token number">1</span><span class="token punctuation">]</span><span class="token punctuation">)</span> / torch.sum<span class="token punctuation">(</span>torch.exp<span class="token punctuation">(</span>p<span class="token punctuation">))</span><span class="token punctuation">)</span>.item<span class="token punctuation">(</span><span class="token punctuation">)</span>
                    print<span class="token punctuation">(</span><span class="token string">"w0={} w1={}"</span>.format<span class="token punctuation">(</span>w0, w1<span class="token punctuation">))</span>
                    print<span class="token punctuation">(</span><span class="token string">""</span><span class="token punctuation">)</span>

<span class="token comment"># Save the Trained Model</span>
print<span class="token punctuation">(</span><span class="token string">"训练完成..."</span><span class="token punctuation">)</span>
torch.save<span class="token punctuation">(</span>model.state_dict<span class="token punctuation">(</span><span class="token punctuation">)</span>, <span class="token string">'./model.pth'</span><span class="token punctuation">)</span>

[Recommended reading]

Pytorch-GPU installation tutorial collection (the "Perfect" series)

Below is sample code for adaptive-weight feature fusion written with PyTorch:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FeatureFusion(nn.Module):
    def __init__(self, input_channels, output_channels):
        super(FeatureFusion, self).__init__()
        # branch 1
        self.conv1 = nn.Conv2d(input_channels, output_channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(output_channels, output_channels, kernel_size=3, padding=1)
        # branch 2 (conv3 takes input_channels, since it is applied directly to x)
        self.conv3 = nn.Conv2d(input_channels, output_channels, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(output_channels, output_channels, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)
        # weight network: produces a per-position weight in (0, 1)
        self.weight_net = nn.Sequential(
            nn.Conv2d(input_channels, output_channels, kernel_size=1),
            nn.BatchNorm2d(output_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(output_channels, output_channels, kernel_size=1),
            nn.BatchNorm2d(output_channels),
            nn.Sigmoid()
        )

    def forward(self, x):
        w = self.weight_net(x)
        x1 = self.relu(self.conv1(x))
        x1 = self.relu(self.conv2(x1))
        x2 = self.relu(self.conv3(x))
        x2 = self.relu(self.conv4(x2))
        x = w * x1 + (1 - w) * x2
        return x
```

In this example we define a module named FeatureFusion that runs its input through two convolutional branches and fuses them adaptively. The module contains four convolutional layers, two per branch, plus a weight network that outputs a weight between 0 and 1 for every position, indicating how much each branch should contribute there. Finally, this weight blends the two branch outputs into the fused feature map.
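A quick shape check of this gating idea, re-sketched minimally so it runs on its own (class and channel sizes here are illustrative, not part of the original code):

```python
import torch
import torch.nn as nn

class GatedFusion(nn.Module):
    """Blend two conv branches of one input with a learned per-position gate in (0, 1)."""
    def __init__(self, c_in, c_out):
        super().__init__()
        self.branch1 = nn.Conv2d(c_in, c_out, kernel_size=3, padding=1)
        self.branch2 = nn.Conv2d(c_in, c_out, kernel_size=3, padding=1)
        self.gate = nn.Sequential(nn.Conv2d(c_in, c_out, kernel_size=1), nn.Sigmoid())

    def forward(self, x):
        w = self.gate(x)  # weight in (0, 1) at every position
        return w * self.branch1(x) + (1 - w) * self.branch2(x)

m = GatedFusion(3, 16)
y = m(torch.randn(2, 3, 32, 32))
print(y.shape)  # torch.Size([2, 16, 32, 32])
```

Note the contrast with the scalar nn.Parameter weights earlier in the post: here the blending weight is predicted from the input itself, so it varies per position and per sample rather than being a single learned constant.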
