卷积神经网络
假设输入x
的shape为(
b
s
,
c
h
i
n
,
h
i
n
,
w
i
n
bs, ch_{in}, h_{in}, w_{in}
bs,chin,hin,win),卷积核的数量为
c
h
o
u
t
ch_{out}
chout, 长宽为(
k
h
,
k
w
k_h, k_w
kh,kw),padding为(
p
h
,
p
w
p_h, p_w
ph,pw),stride为(
s
h
,
s
w
s_h, s_w
sh,sw),则x
经过该卷积核之后的shape(
b
s
,
c
h
o
u
t
,
h
o
u
t
,
w
o
u
t
bs, ch_{out}, h_{out}, w_{out}
bs,chout,hout,wout)公式可用一下来计算:
1、输入经过卷积核之后的大小计算
这个比较简单,相信大家都能够推算出来,这里不多叙述。
h o u t = h i n + 2 ∗ p h − k h s h + 1 w o u t = w i n + 2 ∗ p w − k w s w + 1 \begin{aligned} h_{out} &= \frac{h_{in} + 2*p_{h} - k_h}{s_h} + 1 \\ w_{out} &= \frac{w_{in} + 2*p_{w} - k_w}{s_w} + 1 \end{aligned} houtwout=shhin+2∗ph−kh+1=swwin+2∗pw−kw+1
2、卷积核参数量计算
首先,我们要明白一个卷积核仅有1个偏执参数,我们先算一个卷积核有多少参数,每个卷积核的通道是和输入通道 c h i n ch_{in} chin相等的,也就是每个卷积核有 k h ∗ k w ∗ c h i n k_h * k_w * ch_{in} kh∗kw∗chin个参数(不含偏执),一共有 c h o u t ch_{out} chout个卷积核,那么就有 k h ∗ k w ∗ c h i n ∗ c h o u t k_h * k_w * ch_{in} * ch_{out} kh∗kw∗chin∗chout个参数(不含偏执),最后加上 c h o u t {ch_{out}} chout个偏执,就得到了总的参数量 p a r a m param param。
p a r a m = k h ∗ k w ∗ c h o u t ∗ c h i n + 1 ∗ c h o u t = ( k h ∗ k w ∗ c h i n + 1 ) ∗ c h o u t \begin{aligned} param &= k_h * k_w * ch_{out} * ch_{in} + 1 * ch_{out} \\ &= (k_h * k_w * ch_{in} + 1) * ch_{out} \end{aligned} param=kh∗kw∗chout∗chin+1∗chout=(kh∗kw∗chin+1)∗chout
3、计算量
F
L
O
P
s
FLOPs
FLOPs 是floating point of operations的缩写,是浮点运算次数,理解为计算量,可以用来衡量算法/模型复杂度。
这里推算我们反着来,从结果看,我们得到了(
c
h
o
u
t
,
h
o
u
t
,
w
o
u
t
ch_{out}, h_{out}, w_{out}
chout,hout,wout)个结果, 我们先计算一个结果的计算量f,之后将其×(
c
h
o
u
t
,
h
o
u
t
,
w
o
u
t
ch_{out}, h_{out}, w_{out}
chout,hout,wout)即可得到全部计算量。
要计算f,我们知道这个结果是(1)由卷积核扫描到的区域与卷积核相乘,(2)将所有结果相加,(3)最后加上偏执得到的。假设扫描到的区域为 n ∗ n ∗ c n*n*c n∗n∗c,卷积核大小也应该为 n ∗ n ∗ c n*n*c n∗n∗c,点积我们可以得到 n ∗ n ∗ c n*n*c n∗n∗c个数,这时我们的乘法计算量为 n ∗ n ∗ c n*n*c n∗n∗c, 将这些数相加需要 n ∗ n ∗ c − 1 n*n*c-1 n∗n∗c−1个加法,最后加上偏执,则计算量需要再+1,将这些计算量相加得到一个结果的计算量;共有( c h o u t , h o u t , w o u t ch_{out}, h_{out}, w_{out} chout,hout,wout)个结果,最后乘以结果的数量就得到了总的计算量。
F L O P s = [ ( c h i n ∗ k h ∗ k w ) + ( c h i n ∗ k h ∗ k w − 1 ) + 1 ] ∗ c h o u t ∗ h o u t ∗ w o u t = 2 ∗ c h i n ∗ k h ∗ k w ∗ c h o u t ∗ h o u t ∗ w o u t \begin{aligned} FLOPs &= [(ch_{in} * k_h * k_w) + (ch_{in} * k_h * k_w - 1) + 1] * ch_{out} * h_{out} * w_{out} \\ &= 2 * ch_{in} * k_h * k_w * ch_{out} * h_{out} * w_{out} \end{aligned} FLOPs=[(chin∗kh∗kw)+(chin∗kh∗kw−1)+1]∗chout∗hout∗wout=2∗chin∗kh∗kw∗chout∗hout∗wout