空间金字塔池化(Spatial Pyramid Pooling)方法关联了不定尺寸输出的卷积层和固定大小的全连接层,一方面可以适应不同尺寸图片输入,避免了统一图片大小的前处理操作;另一方面可以提取不同尺寸的空间特征信息,进而提升模型对于空间布局和物体变形的鲁棒性。SPP的基本原理请参考原论文或相关解读,本文基于输入输出尺寸,分析SPP关键参数例如窗口尺寸(kernel)、步长(stride)及边距(padding)的计算方法。
1. 问题提出
已知卷积后输出尺寸 ( w , h ) (w, h) (w,h),空间金字塔池化后目标输出 ( n w , n h ) (n_w, n_h) (nw,nh),计算池化层的窗口尺寸 ( k w , k h ) (k_w, k_h) (kw,kh),步长 ( s w , s h ) (s_w, s_h) (sw,sh)及边距 ( p w , p h ) (p_w, p_h) (pw,ph)。为了简化描述,以下仅基于其中一个维度计算,另一维度采用完全相同的计算公式。因此,相应参数简化为:
已知输入、输出尺寸 w w w和 n n n,求池化窗口尺寸 ( k ) (k) (k),步长 ( s ) (s) (s)及边距 ( p ) (p) (p)。
如果正向计算,公式为:
n = ⌊ w + 2 p − k s ⌋ + 1 (1) n = \left \lfloor \frac {w+2p-k} {s} \right \rfloor + 1 \tag{1} n=⌊sw+2p−k⌋+1(1)
其中 ⌊ x ⌋ \lfloor x \rfloor ⌊x⌋ 表示对 x x x向下取整,例如 ⌊ 1.5 ⌋ = 1 \lfloor 1.5 \rfloor = 1 ⌊1.5⌋=1,同理向上取整符号及例子: ⌈ 1.5 ⌉ = 2 \lceil 1.5 \rceil = 2 ⌈1.5⌉=2。
2. 原始论文公式
原论文中的计算公式:
k = ⌈ w n ⌉ , s = ⌊ w n ⌋ , p = 0 (2) k = \left \lceil \frac {w} {n} \right \rceil, \quad s = \left \lfloor \frac {w} {n} \right \rfloor, \quad p = 0 \tag{2} k=⌈nw⌉,s=⌊nw⌋,p=0(2)
有博文指出了以上公式的问题:
取 w = 7 , n = 4 w=7, n=4 w=7,n=4,根据公式(2)得出 k = 2 , s = 1 , p = 0 k=2,s=1,p=0 k=2,s=1,p=0,然而将池化参数带入公式(1)却得出与输入矛盾的结果: n = 5 n=5 n=5。
实际上,这是作者为论文中特定场景提出的,确实并不具备(作者也没主张)其通用性。
3. 初步修正的公式
参考博文,给出了如下通用性更好的公式:
k = s = ⌈ w n ⌉ p = ⌊ k ∗ n − w + 1 2 ⌋ (3) k = s = \left \lceil \frac {w} {n} \right \rceil \\\\ p = \left \lfloor \frac {k*n-w+1} {2} \right \rfloor \tag{3} k=s=⌈nw⌉p=⌊2k∗n−w+1⌋(3)
对于上一个例子:
w = 7 , n = 4 w=7, n=4 w=7,n=4,根据公式(3)可以得出正确结果: k = 2 , s = 2 , p = 1 k=2,s=2,p=1 k=2,s=2,p=1。
这个公式适用于绝大多数场合,但还是可以找到有问题的例子:
取 w = 5 , n = 4 w=5, n=4 w=5,n=4,根据公式(3)得出 k = 2 , s = 2 , p = 2 k=2,s=2,p=2 k=2,s=2,p=2,
代入公式(1)验证输出尺寸没问题,但是 pytorch要求 padding 不超过 kernel 的一半 即 k > = 2 p k >= 2p k>=2p,显然此处不满足。
4. 可行域分析
为了方便分析这个问题,先排除两种特殊情况:
-
当 n > w n>w n>w 时,不符合SPP的物理意义
-
当 n = 1 n=1 n=1 即输出为1时,取窗口正好为输入尺寸: k = w , s = 1 , p = 0 k=w, s=1, p=0 k=w,s=1,p=0
于是在 w ≥ n > 1 w \geq n \gt 1 w≥n>1 条件下,列出以下限制条件/不等式:
( n − 1 ) ∗ s + k − w ≤ 2 p < n ∗ s + k − w (4-1) (n-1)*s+k-w \leq 2p \lt n*s+k-w \tag{4-1} (n−1)∗s+k−w≤2p<n∗s+k−w(4-1)
0 ≤ 2 p ≤ k (4-2) 0 \leq 2p \leq k \tag{4-2} 0≤2p≤k(4-2)
1 ≤ s ≤ k ≤ w (4-3) 1 \leq s \leq k \leq w \tag{4-3} 1≤s≤k≤w(4-3)
( n − 1 ) ∗ s + k ≥ w + p (4-4) (n-1)*s+k \geq w + p \tag{4-4} (n−1)∗s+k≥w+p(4-4)
其中,
- 不等式(4-1)直接从等式(1)去掉取整符号得到;
- 不等式(4-2)避免引入过多无意义的边距信息,也是 pytorch 中的一个限制;
- 不等式(4-3)要求步长不大于窗口大小,否则跳过了有效区域;
- 不等式(4-4)左边表示池化操作的实际作用范围,右边表示特征图的有效位置,因此整个式子要求池化操作覆盖所有有效区域。
将不等式(4-1)左半部分取整得到 p p p 的计算公式:
p = ⌈ ( n − 1 ) ∗ s + k − w 2 ⌉ (5) p = \left \lceil \frac {(n-1)*s + k - w} {2} \right \rceil \tag{5} p=⌈2(n−1)∗s+k−w⌉(5)
上式代入 k = s k=s k=s 即可得到公式(3)计算 p p p 的部分,表明上式更具一般性,公式(3)的 p p p 只是公式(5)的一个特例。
结合(4-1)左半部分和(4-2)右半部分:
k ≥ 2 p ≥ ( n − 1 ) ∗ s + k − w = > s ≤ w n − 1 k \geq 2p \geq (n-1)*s+k-w => s \leq \frac {w} {n-1} k≥2p≥(n−1)∗s+k−w=>s≤n−1w
结合(4-1)右半部分和(4-3):
0 ≤ 2 p < n ∗ s + k − w ≤ n ∗ k + k − w = > k > w n + 1 0 \leq 2p \lt n*s+k-w \leq n*k+k-w => k \gt \frac {w} {n+1} 0≤2p<n∗s+k−w≤n∗k+k−w=>k>n+1w
不等式(4-4)缩放一下去掉 p p p:
( n − 1 ) ∗ s + k ≥ w (n-1)*s+k \geq w (n−1)∗s+k≥w
综合得到:
{ 1 ≤ s ≤ w n − 1 w n + 1 < k ≤ w k ≥ s k ≥ ( 1 − n ) ∗ s + w (6) \begin{cases} 1 \leq s \leq \frac {w} {n-1} \\\\ \frac {w} {n+1} \lt k \leq w \\\\ k \geq s \\\\ k \geq (1-n)*s + w \end{cases} \tag{6} ⎩⎪⎪⎪⎪⎪⎪⎪⎪⎪⎪⎨⎪⎪⎪⎪⎪⎪⎪⎪⎪⎪⎧1≤s≤n−1wn+1w<k≤wk≥sk≥(1−n)∗s+w(6)
注意各个参数都是非负整数,但此刻先不做区分,直接线性规划求解可行域,得到下图。
显然,解可能不唯一。我们先得到一个特征点 P 0 ( w / n , w / n ) P_0(w/n, w/n) P0(w/n,w/n),然后基于不同的策略有不同的选择:
-
如果沿着绿色箭头方向往 P 1 P_1 P1 方向走,窗口大小始终与步长相等,即传统的池化模式。
-
如果沿着青色箭头方向往 P 2 P_2 P2方向走,窗口大小始终大于步长,即带重叠模式的池化。
以 P 1 P_1 P1 方向为例,因为 k , s k,s k,s 都是正整数,我们取 P 0 P_0 P0 右侧最接近的正整数值,即 k = s = ⌈ w / n ⌉ k = s = \lceil w / n \rceil k=s=⌈w/n⌉ ,于是得到了网上常见的初步修正的公式,即上文的公式(3)。
至此,可以统一前文提及的计算方法,并且解释以下两个问题:
(a)公式(3)在什么情况下不再适用?
对照可行域图就很好解释了——绿色线段上可能不存在整数解。
例如 w = 5 , n = 4 w=5,n=4 w=5,n=4,绿色线段两个端点的 s s s 坐标分别为 1.25 和 1.667,二者之间并不存在正整数。
那么,公式(3)在什么条件下才适用呢?令 w = a ∗ n + b w=a*n+b w=a∗n+b,其中 0 ≤ b < n 0 \leq b \lt n 0≤b<n,则
w n − 1 = a ∗ n + b n − 1 = a + a + b n − 1 \frac {w} {n-1} = \frac {a*n+b} {n-1} = a + \frac {a+b} {n-1} n−1w=n−1a∗n+b=a+n−1a+b
显然, w / ( n − 1 ) w/(n-1) w/(n−1) 的整数部分至少达到 a + 1 a+1 a+1 即 ( a + b ) / ( n − 1 ) ≥ 1 (a+b)/(n-1) \geq 1 (a+b)/(n−1)≥1 时,绿色线段标注的可行域上才有整数解:
⌊ w n ⌋ + ( w m o d n ) + 1 ≥ n \left \lfloor \frac {w} {n} \right \rfloor + \left(w \mod n\right) + 1 \geq n ⌊nw⌋+(wmodn)+1≥n
进一步考虑端点上的情况,即上式取等号,此时 w / ( n − 1 ) w/(n-1) w/(n−1) 恰好为整数,且 k = s = w / ( n − 1 ) k=s=w/(n-1) k=s=w/(n−1),参考公式(5)可知:
p = ⌈ n ∗ k − w 2 ⌉ = ⌈ w 2 ∗ ( n − 1 ) ⌉ = ⌈ k 2 ⌉ p = \left \lceil \frac {n*k - w} {2} \right \rceil = \left \lceil \frac {w} {2 *(n-1)} \right \rceil = \left \lceil \frac {k} {2} \right \rceil p=⌈2n∗k−w⌉=⌈2∗(n−1)w⌉=⌈2k⌉
结合 k ≥ 2 p k \geq 2p k≥2p 的限定条件,此时要求 k k k 即 w / ( n − 1 ) w/(n-1) w/(n−1) 必须为偶数。
综上,公式(3)的使用条件:
t = ⌊ w n ⌋ + ( w m o d n ) + 1 t = \left \lfloor \frac {w} {n} \right \rfloor + \left(w \mod n\right) + 1 t=⌊nw⌋+(wmodn)+1
t > n ∨ ( t = n ∧ w / ( n − 1 ) m o d 2 = 0 ) (7) t \gt n \quad \lor \quad \left( t = n \quad \land \quad w/(n-1) \mod 2 = 0 \right) \tag{7} t>n∨(t=n∧w/(n−1)mod2=0)(7)
其中, ∨ \lor ∨ 和 ∧ \land ∧ 分别表示“或”和“且”。
(b)如何处理公式(3)不适用的情况?
当 w , n w,n w,n 不满足不等式(7)时,公式(3)失效,那就走 P 2 P_2 P2 的路线,如青色箭头所示:
-
此种情况下往右显然不存在可行的 s s s了,于是向左一步得到 P 0 P_0 P0 附近的 s s s;
-
然后向上增大 k k k 直到满足可行域要求。
以上过程反映了公式(2)的思路,但是为了更具通用性,确定 k k k 时需要检查是否落在可行域内。将公式(2)中 s s s 的表达式代入(4-4)的缩放式得到 k k k,然后将 k , s k,s k,s 代入公式(5)计算 p p p,最终得到公式(2)的更一般形式:
s = ⌊ w n ⌋ , k = w − ( n − 1 ) ∗ s , p = 0 (8) s = \left \lfloor \frac {w} {n} \right \rfloor, \quad k = w - (n-1)*s, \quad p = 0 \tag{8} s=⌊nw⌋,k=w−(n−1)∗s,p=0(8)
回到 w = 5 , n = 4 w=5,n=4 w=5,n=4 的例子,代入上式得到 k = 2 , s = 1 , p = 0 k=2, s=1, p=0 k=2,s=1,p=0,满足所有约束。
注意:上式和公式(2)的最直接区别是 k k k 的计算方法。公式(2)在定义 k , s k,s k,s 的同时强行设定 p = 0 p=0 p=0(或者说忽略了 p p p 的计算),实际上三者是相互关联的。公式(8)通过构造 k k k,使 p = 0 p=0 p=0 自然得到满足。
5. 完整公式
已知输入、输出尺寸 w w w 和 n n n( w ≥ n w \geq n w≥n),求SPP池化层的窗口尺寸 ( k ) (k) (k),步长 ( s ) (s) (s)及边距 ( p ) (p) (p)。
(1)当满足不等式(7)时,
s = k = ⌈ w n ⌉ , p = ⌈ n ∗ k − w 2 ⌉ s = k = \left \lceil \frac {w} {n} \right \rceil, \quad p = \left \lceil \frac {n*k - w} {2} \right \rceil s=k=⌈nw⌉,p=⌈2n∗k−w⌉
(2)当不满足不等式(7)时,
s = ⌊ w n ⌋ , k = w − ( n − 1 ) ∗ s , p = 0 s = \left \lfloor \frac {w} {n} \right \rfloor, \quad k = w - (n-1)*s, \quad p = 0 s=⌊nw⌋,k=w−(n−1)∗s,p=0
注意:以上算法优先选择传统非重叠的池化方式,只有在无法满足时,才考虑重叠的池化方式。如果倾向于重叠的池化方式,则直接选择第(2)部分计算公式即可。