November 17, 2021, 11:32:14
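Today we implement adaptive, learnable weight coefficients in PyTorch: when fusing features, each feature map gets its own learnable weight!
Original paper: 基于自适应特征融合与转换的小样本图像分类 (Few-Shot Image Classification Based on Adaptive Feature Fusion and Transformation), 2021
Journal: 计算机工程与应用 (Computer Engineering and Applications; Chinese Core Journal, CSCD expanded edition)
Here we implement the multi-feature fusion branch from this paper!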
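The core trick is to wrap the fusion weights in `nn.Parameter`. This registers them with the module, so the optimizer updates them along with the convolution weights. Below is a minimal sketch that assumes all inputs are feature maps of the same shape. The class name `WeightedFusion` is hypothetical. This version uses plain, unnormalized weights rather than the softmax normalization in the paper's formula.

```python
import torch
import torch.nn as nn


class WeightedFusion(nn.Module):
    """Fuse n same-shaped feature maps with one learnable scalar weight each."""

    def __init__(self, num_inputs: int):
        super().__init__()
        # nn.Parameter registers the tensor with the module, so it appears in
        # .parameters() and is updated by the optimizer like any other weight.
        self.weights = nn.Parameter(torch.ones(num_inputs))

    def forward(self, *features: torch.Tensor) -> torch.Tensor:
        assert len(features) == self.weights.numel(), "one weight per input"
        # Unnormalized weighted sum: sum_i weights[i] * features[i].
        return sum(w * f for w, f in zip(self.weights, features))


if __name__ == "__main__":
    fuse = WeightedFusion(num_inputs=2)
    a, b = torch.randn(1, 8, 5, 5), torch.randn(1, 8, 5, 5)
    optimizer = torch.optim.SGD(fuse.parameters(), lr=0.1)
    loss = fuse(a, b).pow(2).mean()  # toy loss, for illustration only
    loss.backward()
    optimizer.step()  # the fusion weights are updated by this step
    print(fuse.weights)
```

Without normalization, these weights can grow without bound or change sign. That is one reason to pass them through a softmax before use.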
The adaptive feature processing module we implement is shown in the figure below:
The feature fusion formula is:
$$
\begin{aligned}
F_{ff} &= \alpha_1 F_{id} + \alpha_2 F_{dconv} + \alpha_3 F_{\max} + \alpha_4 F_{avg} \\
\alpha_i &= \frac{e^{w_i}}{\sum_{j=1}^{4} e^{w_j}}, \qquad i = 1, 2, 3, 4
\end{aligned}
$$

Here $w_1, \dots, w_4$ are the learnable parameters. The softmax keeps every $\alpha_i$ positive and makes them sum to 1, so the fused map $F_{ff}$ is a convex combination of the four branch outputs.
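Below is a minimal PyTorch sketch of the formula above. The branch definitions are assumptions, since the figure is not reproduced here:
- **$F_{id}$:** the input itself.
- **$F_{dconv}$:** stood in for by a hypothetical 3×3 dilated convolution. The paper's "dconv" may be a different operator.
- **$F_{\max}$ and $F_{avg}$:** 3×3 max and average pooling with stride 1.

The class name `AdaptiveFeatureFusion` is also hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class AdaptiveFeatureFusion(nn.Module):
    """Sketch of F_ff = a1*F_id + a2*F_dconv + a3*F_max + a4*F_avg,
    where a = softmax(w) over four learnable scalars w."""

    def __init__(self, channels: int):
        super().__init__()
        # Hypothetical dconv branch: 3x3 dilated conv (dilation=2, padding=2),
        # which keeps the spatial size unchanged.
        self.dconv = nn.Conv2d(channels, channels, kernel_size=3,
                               padding=2, dilation=2)
        # Hypothetical pooling branches: 3x3 window, stride 1, padding 1,
        # so every branch has the same shape as the input.
        self.max_pool = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
        self.avg_pool = nn.AvgPool2d(kernel_size=3, stride=1, padding=1)
        # Raw learnable weights w_1..w_4; zeros give alpha_i = 0.25 at the start.
        self.w = nn.Parameter(torch.zeros(4))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Branch outputs stacked along a new leading dim: (4, N, C, H, W).
        branches = torch.stack(
            [x, self.dconv(x), self.max_pool(x), self.avg_pool(x)], dim=0)
        # alpha_i = exp(w_i) / sum_j exp(w_j)
        alpha = F.softmax(self.w, dim=0)
        # Weight each branch by its alpha_i and sum over the branch dim.
        return (alpha.view(4, 1, 1, 1, 1) * branches).sum(dim=0)


if __name__ == "__main__":
    fusion = AdaptiveFeatureFusion(channels=64)
    x = torch.randn(2, 64, 21, 21)
    out = fusion(x)
    print(out.shape)  # torch.Size([2, 64, 21, 21])
    # The normalized fusion weights alpha_1..alpha_4.
    print(F.softmax(fusion.w, dim=0))
```

Initializing `w` to zeros means the module starts as a plain average of the four branches, and training shifts the balance from there. After training, `F.softmax(fusion.w, dim=0)` shows how much each branch contributes.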
$$
\begin{aligned}
F_{ff} &= \alpha_1 F_{id} + \alpha_2 F_{dconv} + \alpha_3 F_{\max} + \alpha_4 F_{avg} \\
\alpha_i &= \frac{e^{w_i}}{\sum_{j} e^{w_j}} \qquad (i = 1,2,3,4;\ j = 1,2,3,4)
\end{aligned}
$$
where $\alpha_i$ is the normalized weight of branch $i$, with $\sum_i \alpha_i = 1$, and $w_i$ is the raw weight coefficient, which is initialized (to 1 in this implementation) and then learned.
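To make the normalization concrete, here is a small pure-Python sketch (no PyTorch) of the two formulas above: it turns the raw coefficients $w_i$ into $\alpha_i$ with a softmax and fuses four scalar stand-ins for the branch outputs. The function names and toy values are illustrative assumptions, not taken from the paper.

```python
import math

def normalize_weights(w):
    """alpha_i = exp(w_i) / sum_j exp(w_j), i.e. a softmax over the raw coefficients."""
    m = max(w)  # subtracting the max avoids overflow; the ratios are unchanged
    exps = [math.exp(wi - m) for wi in w]
    total = sum(exps)
    return [e / total for e in exps]

def fuse(features, w):
    """F_ff = sum_i alpha_i * F_i, with scalars standing in for the four branch outputs."""
    alphas = normalize_weights(w)
    return sum(a * f for a, f in zip(alphas, features))

# With every w_i initialized to 1, each branch starts with the same weight 1/4.
print(normalize_weights([1.0, 1.0, 1.0, 1.0]))  # [0.25, 0.25, 0.25, 0.25]

# Hypothetical values of (F_id, F_dconv, F_max, F_avg) at one position.
print(fuse([2.0, 4.0, 6.0, 8.0], [1.0, 1.0, 1.0, 1.0]))  # 5.0, a plain average
```

A larger $w_i$ gives its branch a larger share, and because the $\alpha_i$ always sum to 1, raising one weight necessarily shrinks the others.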
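Structure analysis:
- Each input feature map is processed by four branches.
- From top to bottom in the figure, the first branch applies max pooling (MaxPooling) to extract local features.
- The second branch applies average pooling (AvgPooling) to extract global features.
- The third branch is described in the paper as follows: "Two groups of 1×1 convolutions halve the number of channels, first to reduce the number of parameters and prevent overfitting, and second to make it convenient to concatenate the convolutional features for additive fusion; then, after the first group of 1×1 convolutions, two 3×3 convolutions are added in place of a 5×5 convolution, and the outputs are concatenated along the channel dimension (Combine denotes channel-wise concatenation)." The paper calls this the double-convolution branch, DConv. The convolutions extract rich features, and an SE attention module is applied after the concatenation.
- The fourth branch is the residual branch, Identity, which passes the input straight through as a skip connection to preserve the original features.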
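The DConv branch must split the channels between its two 1×1 paths and then concatenate them back to the input width; otherwise the final Add with the other branches fails. A small sketch of that bookkeeping, assuming the first path gets C // 2 channels and the second takes the remainder (this matches the AFP code's choice of 1 + 2 channels for C = 3). The helper name `dconv_channel_split` is hypothetical.

```python
def dconv_channel_split(in_channels):
    """Return (c1, c2), the output channels of the two 1x1 paths in the DConv branch.

    c1 = C // 2 ("halve the channels"); c2 takes the remainder so that
    concatenating both paths along the channel axis restores C channels.
    """
    c1 = in_channels // 2
    c2 = in_channels - c1
    return c1, c2

for c in (3, 64):
    c1, c2 = dconv_channel_split(c)
    print(c, "->", c1, "+", c2, "=", c1 + c2)
# 3 -> 1 + 2 = 3
# 64 -> 32 + 32 = 64
```

For an odd channel count such as 3 the two paths cannot be exactly equal, so one of them has to carry the extra channel.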
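Model analysis:
- Since the fusion ends with an element-wise Add, the four branch outputs must have the same spatial size and number of channels. The skip connection leaves the input's size and channels unchanged, so every branch must produce an output with the same shape as the input.
- In Section 3.2.2 (parameter settings), the paper sets the pooling size of the max-pooling branch to 3 and that of the average-pooling branch to 2.
- All branch weights are initialized to 1 and implemented with nn.Parameter.
- The input image size is 3×84×84.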
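Why the code uses kernel size 3 for both pooling branches: along one axis, a pooling layer with input size H, kernel k, stride s and padding p outputs ⌊(H + 2p − k)/s⌋ + 1 (PyTorch's default floor mode, dilation 1). A quick check, assuming stride 1 since the paper does not specify it; the function name is illustrative.

```python
def pool_out_size(h, kernel, stride=1, padding=0):
    """Output size of a 2D pooling layer along one axis (floor mode, dilation 1)."""
    return (h + 2 * padding - kernel) // stride + 1

H = 84
print(pool_out_size(H, kernel=3, stride=1, padding=1))  # 84: size preserved
print(pool_out_size(H, kernel=2, stride=1, padding=0))  # 83: one pixel short
print(pool_out_size(H, kernel=2, stride=1, padding=1))  # 85: one pixel too many
```

With stride 1, keeping the size requires 2p = k − 1, which is impossible for an even kernel. That is why the code uses (3, 1, 1) for the average-pooling branch instead of the paper's pooling size of 2.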
PyTorch implementation of the AFP module
"""
Author: yida
Time is: 2021/11/17 15:45
this Code:
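1. Implements the adaptive feature processing module AFP from "Few-Shot Image Classification Based on Adaptive Feature Fusion and Transformation"
2. Demonstrates how to use nn.Parameter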
"""
import torch
import torch.nn as nn
class AFP(nn.Module):
    def __init__(self):
        super(AFP, self).__init__()

        # 1. Max-pooling branch. The paper sets the pooling size to 3 but gives no stride or padding;
        #    (kernel=3, stride=1, padding=1) keeps the output the same size as the input.
        self.branch1 = nn.Sequential(
            nn.MaxPool2d(3, 1, padding=1),
        )
        # 2. Average-pooling branch. The paper sets the pooling size to 2 but gives no stride or padding;
        #    to keep the output the same size as the input, (kernel=3, stride=1, padding=1) is used instead.
        self.branch2 = nn.Sequential(
            nn.AvgPool2d(3, 1, padding=1),
        )
        # 3_1. DConv sub-branch: a 1x1 conv compresses the channels, then two 3x3 convs extract features.
        #      Half of 3 input channels is 3 // 2 = 1, so this sub-branch outputs 1 channel.
        self.branch3_1 = nn.Sequential(
            nn.Conv2d(3, 1, 1),
            nn.Conv2d(1, 1, 3, padding=1),
            nn.Conv2d(1, 1, 3, padding=1),
        )
        # 3_2. DConv sub-branch: the 1x1 conv should also halve the channels, but the input has 3 channels
        #      and branch 3_1 already uses 1, so this one outputs 2 (1 + 2 = 3 after concatenation).
        self.branch3_2 = nn.Sequential(
            nn.Conv2d(3, 2, 1),
            nn.Conv2d(2, 2, 3, padding=1),
        )
        # SE attention applied to the concatenated DConv output
        self.branch_SE = SEblock(channel=3)

        # Learnable weight coefficients.
        # A tensor wrapped in nn.Parameter and assigned as a module attribute is registered as a parameter,
        # so it is returned by model.parameters() and updated by the optimizer to minimize the loss.
        self.w = nn.Parameter(torch.ones(4))  # one adaptive weight per branch (4 branches), initialized to 1; nn.Parameter takes a Tensor
        # self.w = nn.Parameter(torch.Tensor([0.5, 0.25, 0.15, 0.1]), requires_grad=False)  # fixed weights instead; if used, skip the softmax in forward and multiply directly

    def forward(self, x):
        b1 = self.branch1(x)  # max pooling
        b2 = self.branch2(x)  # average pooling
        b3_1 = self.branch3_1(x)
        b3_2 = self.branch3_2(x)
        b3_Combine = torch.cat((b3_1, b3_2), dim=1)  # Combine: channel-wise concatenation
        b3 = self.branch_SE(b3_Combine)  # DConv + SE
        b4 = x  # Identity
        print("b1:", b1.shape)
        print("b2:", b2.shape)
        print("b3:", b3.shape)
        print("b4:", b4.shape)

        # Normalize the weights: alpha_i = exp(w_i) / sum_j exp(w_j), i.e. a softmax over self.w
        w1, w2, w3, w4 = torch.softmax(self.w, dim=0)

        # Multi-feature fusion (weighted sum)
        x_out = b1 * w1 + b2 * w2 + b3 * w3 + b4 * w4
        print("Feature fusion result:", x_out.shape)
        return x_out
class SEblock(nn.Module):  # Squeeze-and-Excitation attention module
    def __init__(self, channel, r=0.5):  # channel: input channels; r: scaling ratio of the FC layers, controls the hidden size
        super(SEblock, self).__init__()
        # Global average pooling
        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
        # Fully connected layers
        self.fc = nn.Sequential(
            nn.Linear(channel, int(channel * r)),  # int(channel * r) truncates to an integer
            nn.ReLU(),
            nn.Linear(int(channel * r), channel),
            nn.Sigmoid(),
        )

    def forward(self, x):
        # Squeeze: global average pooling to (b, c, 1, 1), then flatten to (b, c)
        branch = self.global_avg_pool(x)
        branch = branch.view(branch.size(0), -1)
        # Excitation: the FC layers produce per-channel weights in (0, 1)
        weight = self.fc(branch)
        # Reshape the (b, c) weights to (b, c, 1, 1) so they broadcast over x
        b, c = weight.shape
        weight = torch.reshape(weight, (b, c, 1, 1))
        # Scale each channel of x by its weight
        scale = weight * x
        return scale
if __name__ == '__main__':
    model = AFP()
    print(model)
    inputs = torch.randn(10, 3, 84, 84)
    print("Input shape: ", inputs.shape)
    outputs = model(inputs)
    print("Output shape: ", outputs.shape)

    # Inspect the values stored in nn.Parameter; during training, run this loop after each
    # optimizer step to watch the weights change
    for name, p in model.named_parameters():
        if name == 'w':
            print("Feature weights: ", name)
            w0, w1, w2, w3 = torch.softmax(p.detach(), dim=0).tolist()
            print(w0, w1, w2, w3)
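The four `exp(p[k]) / sum(exp(p))` lines above are simply a softmax over the weight vector. As a minimal sketch (assuming `p` is the 1-D `nn.Parameter` named `w`; the values below are made up for illustration), `torch.softmax` computes all the normalized coefficients $\alpha_i$ in one call and gives the same numbers:

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the learnable parameter named 'w' (4 fusion branches)
p = nn.Parameter(torch.tensor([0.5, -1.0, 2.0, 0.0]))

# Manual form used in the code above: alpha_k = exp(p_k) / sum_j exp(p_j)
manual = [(torch.exp(p[k]) / torch.sum(torch.exp(p))).item() for k in range(4)]

# Equivalent one-liner; detach() because the values are only printed, not trained
alphas = torch.softmax(p.detach(), dim=0)
print(alphas.tolist())
```

Besides being shorter, `torch.softmax` typically subtracts the maximum before exponentiating, so it is less prone to overflow if some $w_i$ grow large.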
nn.Parameter: the learnable weight coefficients $w_i$ in the feature-fusion formula above, normalized as $\alpha_i = e^{w_i} / \sum_j e^{w_j}$
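To make the formula $F_{ff} = \alpha_1 F_{id} + \alpha_2 F_{dconv} + \alpha_3 F_{max} + \alpha_4 F_{avg}$ concrete, here is a minimal sketch of a four-branch fusion module with one learnable $w_i$ per branch. It is not the paper's implementation: the class name `AdaptiveFusion` and the branch definitions (a 3x3 depthwise convolution for $F_{dconv}$, 3x3 stride-1 pooling for $F_{max}$ and $F_{avg}$) are assumptions chosen so that all branches keep the input shape and can be summed.

```python
import torch
import torch.nn as nn

class AdaptiveFusion(nn.Module):
    """Sketch of F_ff = a1*F_id + a2*F_dconv + a3*F_max + a4*F_avg (hypothetical class).

    Assumed branches: 3x3 depthwise conv, 3x3 max/avg pooling with stride 1 and
    padding 1, so every branch output has the same shape as the input.
    """

    def __init__(self, channels):
        super().__init__()
        self.dconv = nn.Conv2d(channels, channels, kernel_size=3, padding=1, groups=channels)
        self.max_pool = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
        self.avg_pool = nn.AvgPool2d(kernel_size=3, stride=1, padding=1)
        # One learnable scalar w_i per branch; ones -> equal weights (0.25 each) at the start
        self.w = nn.Parameter(torch.ones(4))

    def forward(self, x):
        alpha = torch.softmax(self.w, dim=0)  # alpha_i = exp(w_i) / sum_j exp(w_j)
        return (alpha[0] * x                      # F_id: identity branch
                + alpha[1] * self.dconv(x)        # F_dconv
                + alpha[2] * self.max_pool(x)     # F_max
                + alpha[3] * self.avg_pool(x))    # F_avg

fusion = AdaptiveFusion(channels=8)
out = fusion(torch.randn(2, 8, 16, 16))
print(out.shape)  # torch.Size([2, 8, 16, 16])
```

Because the softmax keeps the $\alpha_i$ positive and summing to 1, the fused map stays on the same scale as the inputs while the network learns which branch to emphasize.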
Using nn.Parameter: setting up learnable weights
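The key point is that assigning an `nn.Parameter` as a module attribute registers it, so it shows up in `named_parameters()` and is updated by the optimizer like any convolution weight. A minimal sketch with a hypothetical toy module `TwoBranch` (two scalar "branches" fused by learnable weights, with a loss that favors the first branch):

```python
import torch
import torch.nn as nn

class TwoBranch(nn.Module):
    # Hypothetical toy module: two branch outputs fused with learnable weights
    def __init__(self):
        super().__init__()
        # Assigning an nn.Parameter as an attribute registers it under the name 'w'
        self.w = nn.Parameter(torch.ones(2))

    def forward(self, a, b):
        alpha = torch.softmax(self.w, dim=0)
        return alpha[0] * a + alpha[1] * b

model = TwoBranch()
print([name for name, _ in model.named_parameters()])  # ['w']

# 'w' is part of model.parameters(), so the optimizer updates it
optimizer = torch.optim.SGD(model.parameters(), lr=1.0)
a, b = torch.tensor(1.0), torch.tensor(0.0)
before = model.w.detach().clone()
loss = (model(a, b) - 1.0) ** 2  # target 1.0 favors branch a
optimizer.zero_grad()
loss.backward()
optimizer.step()
print((model.w.detach() - before).tolist())  # [0.25, -0.25]: w[0] grows, w[1] shrinks
```

With equal initial weights the output is 0.5, the gradient pushes weight toward the branch that reduces the loss, and a single SGD step with `lr=1.0` moves the two entries of `w` by +0.25 and -0.25.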
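Update log
- April 14, 2022, 17:37:53
Since quite a few readers have been following this post, I found a handwritten digit recognition script that runs out of the box, added a learnable parameter to it, and print its values during training, so that you can get a better feel for how learnable parameters change.
Handwritten digit recognition code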
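"""
Author: yida
Time is: 2022/3/6 09:30
this Code: original source: https://www.cnblogs.com/wj-1314/p/9842719.html
- Code: handwritten digit recognition adapted from the link above; it has only two convolutional layers, and each convolutional layer gets its own weight coefficient w
- Runs as-is: torchvision downloads the MNIST dataset automatically into ./mnist in the current folder, and the model is saved as ./model.pth
- Testing is not implemented; feel free to add it yourself
- To make the learnable parameters easier to understand,
- they are placed directly in the code and printed during training
"""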
import os
import torch
import torch.nn as nn
import torchvision.datasets as normal_datasets
import torchvision.transforms as transforms
from torch.autograd import Variable
# Work around the "libiomp5 already initialized" OpenMP error seen on some conda/MKL setups
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
# Two convolutional layers
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        # Build each block quickly with nn.Sequential
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=5, padding=2),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(2))
        self.conv2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=5, padding=2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2))
        # 28x28 -> 14x14 -> 7x7 after the two MaxPool2d(2) layers, hence 7 * 7 * 32 inputs
        self.fc = nn.Linear(7 * 7 * 32, 10)
        self.w = nn.Parameter(torch.ones(2))  # initialize the weights: one weight for each of the 2 conv blocks

    def forward(self, x):
        # Normalize the weights (softmax over self.w)
        w1 = torch.exp(self.w[0]) / torch.sum(torch.exp(self.w))
        w2 = torch.exp(self.w[1]) / torch.sum(torch.exp(self.w))
        out = self.conv1(x) * w1
        out = self.conv2(out) * w2
        out = out.view(out.size(0), -1)  # reshape
        out = self.fc(out)
        return out
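# Wrap the data as a Variable and move it to the GPU if one is available
# (since PyTorch 0.4, Variable is a deprecated wrapper that simply returns the tensor)
def get_variable(x):
    x = Variable(x)
    return x.cuda() if torch.cuda.is_available() else x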
if __name__ == '__main__':
    num_epochs = 5
    batch_size = 100
    learning_rate = 0.001
    # Load a commonly used dataset from torchvision.datasets
    train_dataset = normal_datasets.MNIST(
        root='./mnist/',  # where the dataset is stored
        train=True,  # use the training split
        transform=transforms.ToTensor(),  # how to preprocess the data; can be customized
        download=True)  # download it if it is not already under root
    # Test split, then the data loaders that produce batches
    test_dataset = normal_datasets.MNIST(root='./mnist/',
                                         train=False,
                                         transform=transforms.ToTensor())
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True)
    test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                              batch_size=batch_size,
                                              shuffle=False)
    model = CNN()
    if torch.cuda.is_available():
        model = model.cuda()
    # Choose the loss function and the optimizer
    loss_func = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    for epoch in range(num_epochs):
        for i, (images, labels) in enumerate(train_loader):
            images = get_variable(images)
            labels = get_variable(labels)
            outputs = model(images)
            loss = loss_func(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if (i + 1) % 100 == 0:
                print('Epoch [%d/%d], Iter [%d/%d] Loss: %.4f'
                      % (epoch + 1, num_epochs, i + 1, len(train_dataset) // batch_size, loss.item()))
                # Print how the weights w change during training
                for name, p in model.named_parameters():
                    if name == 'w':
                        print("Feature weights: ", name)
                        w0 = (torch.exp(p[0]) / torch.sum(torch.exp(p))).item()
                        w1 = (torch.exp(p[1]) / torch.sum(torch.exp(p))).item()
                        print("w0={} w1={}".format(w0, w1))
                print("")
    # Save the trained model
    print("Training finished...")
    torch.save(model.state_dict(), './model.pth')
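Since the script saves `model.state_dict()` to `./model.pth`, the learned weights can be inspected afterwards without retraining: the state_dict maps parameter names to tensors, so the raw `w` sits under the key `'w'`. A minimal sketch, where the helper `fusion_weights_from_checkpoint` and the stand-in module `Tiny` are hypothetical and only used so the example runs on its own:

```python
import os
import tempfile
import torch
import torch.nn as nn

def fusion_weights_from_checkpoint(path, key="w"):
    """Return the softmax-normalized learnable weights stored under `key` in a saved state_dict."""
    state = torch.load(path, map_location="cpu")  # dict: parameter name -> tensor
    return torch.softmax(state[key], dim=0).tolist()

# Hypothetical stand-in module; with the training script you would pass './model.pth'
class Tiny(nn.Module):
    def __init__(self):
        super().__init__()
        self.w = nn.Parameter(torch.tensor([0.0, 0.0]))

with tempfile.TemporaryDirectory() as d:
    path = os.path.join(d, "model.pth")
    torch.save(Tiny().state_dict(), path)
    weights = fusion_weights_from_checkpoint(path)
print(weights)  # [0.5, 0.5]
```

`map_location="cpu"` lets a checkpoint saved on a GPU machine be read on a CPU-only one.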
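The script's docstring notes that testing is not implemented, even though `test_loader` is already built. One possible evaluation helper is sketched below; the name `evaluate` and its `device` argument are hypothetical, and the toy model and synthetic data exist only to make the example self-contained. With the script above it could be called as `evaluate(model, test_loader, "cuda" if torch.cuda.is_available() else "cpu")` after training.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

@torch.no_grad()
def evaluate(model, loader, device="cpu"):
    """Return the classification accuracy of `model` on `loader` (hypothetical helper)."""
    model.eval()  # BatchNorm uses its running statistics in eval mode
    correct, total = 0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        preds = model(images).argmax(dim=1)  # predicted class per sample
        correct += (preds == labels).sum().item()
        total += labels.size(0)
    model.train()  # switch back to training mode afterwards
    return correct / total

# Synthetic check: a linear "classifier" whose logits always favor class 1
toy = nn.Linear(4, 3)
nn.init.zeros_(toy.weight)
with torch.no_grad():
    toy.bias.copy_(torch.tensor([0.0, 1.0, 0.0]))
data = TensorDataset(torch.randn(10, 4), torch.tensor([1] * 6 + [0] * 4))
print(evaluate(toy, DataLoader(data, batch_size=4)))  # 0.6
```

Calling `model.eval()` matters here because the CNN contains BatchNorm layers, which behave differently during training and inference.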