论文:https://arxiv.org/abs/2003.03836
代码:https://github.com/PRIS-CV/PMG-Progressive-Multi-Granularity-Training
1 摘要
细粒度分类(Fine-Grained visual classification,FGVC)由于目标类内变化大,分类的难度更大。近期解决细粒度分类的主要思路是找出图像中最有判别力的区域或者是最互补的区域,或者是使用多粒度的特征进行分类。作者提出了一个新的思路,使用跨尺度的特征融合进行细粒度分类。具体的做法是:
- 使用渐进式训练策略,每一轮迭代都是在原有更小尺度特征的基础上添加新的网络层进行模型训练;
- 使用图像分块再混合拼接的思路构建训练图像。
作者的思路取得了和当前SOTA算法相当的分类效果。
2 核心思想
作者提出的方法提取的特征和其他算法的对比如下图所示:
论文的总体核心思想如下图所示,训练网络时迫使网络的浅层去学习细粒度的特征,然后使网络的更深层逐渐转移注意力在更大的粒度上去学习更加抽象的特征。
2.1 网络结构
作者提出的思想可以用在任何的特征提取骨干网络上。
使用 F F F表示特征提取器,其包含 L L L个Stage。网络中间阶段输出的feature map可以表示成 F l ∈ R H l × W l × C l F^l \in R^{H_l \times W_l \times C_l} Fl∈RHl×Wl×Cl,其中 H l , W l , C l H_l,W_l,C_l Hl,Wl,Cl表示第 l l l个stage输出的feature map的高、宽和channel数, l ∈ { 1 , 2 , ⋯ , L } l \in \{1,2,\cdots,L\} l∈{1,2,⋯,L}。
训练的目标是减少基于不同stage输出的feature map进行分类的交叉熵损失。在F的基础上,使用卷积块
H
c
o
n
v
l
H_{conv}^l
Hconvl,以第
l
l
l个stage输出的feature map
F
l
F^l
Fl为输入,将其变换为向量表示
V
l
=
H
c
o
n
v
l
(
F
l
)
V^l = H_{conv}^l(F^l)
Vl=Hconvl(Fl)。而后再使用由两个使用了BN和Elu的全连接层组成的分类模块
H
c
l
a
s
s
l
H_{class}^l
Hclassl,预测第
l
l
l个stage预测的分类结果,即
y
l
=
H
c
l
a
s
s
l
(
V
l
)
y^l = H_{class}^l(V^l)
yl=Hclassl(Vl)。这里,我们考虑了最后的S个stage,即
l
=
L
,
L
−
1
,
⋯
,
L
−
S
+
1
l = L,L-1,\cdots,L-S+1
l=L,L−1,⋯,L−S+1,将这些stage的输出拼接起来,得到:
然后使用一个额外的分类模块,得到
y
c
o
n
c
a
t
=
H
c
l
a
s
s
c
o
n
c
a
t
(
V
c
o
n
c
a
t
)
y_{concat} = H_{class}^{concat}(V_{concat})
yconcat=Hclassconcat(Vconcat)。
2.2 渐进式训练
渐进式训练即先训练网络的低层部分,然后不停的添加新的stage。由于网络低层的表示能力很有限,这就迫使网络的低层stage去学习有判别力的局部特征。相比于整体进行网络训练时同时进行多粒度的训练,渐进式训练使得网络逐渐学习到由局部到全局的特征信息。
因为是分类任务,作者使用的是交叉熵损失,对于每一个中间stage,损失为:
对于拼接后的特征,损失为:
作者使用的渐进式训练策略,起到了对不同粒度特征进行融合的目的,实现思想即一个batch的训练图像先用来训练低层网络,然后使用优化过的低层网络层训练网络的更高层,具体实现代码如下所示:
# Step 1
optimizer.zero_grad()
inputs1 = jigsaw_generator(inputs, 8)
output_1, _, _, _ = netp(inputs1)
loss1 = CELoss(output_1, targets) * 1
loss1.backward()
optimizer.step()
# Step 2
optimizer.zero_grad()
inputs2 = jigsaw_generator(inputs, 4)
_, output_2, _, _ = netp(inputs2)
loss2 = CELoss(output_2, targets) * 1
loss2.backward()
optimizer.step()
# Step 3
optimizer.zero_grad()
inputs3 = jigsaw_generator(inputs, 2)
_, _, output_3, _ = netp(inputs3)
loss3 = CELoss(output_3, targets) * 1
loss3.backward()
optimizer.step()
# Step 4
optimizer.zero_grad()
_, _, _, output_concat = netp(inputs)
concat_loss = CELoss(output_concat, targets) * 2
concat_loss.backward()
optimizer.step()
2.3 拼图
单独的使用渐进式训练对细粒度分类是不利的,因为这会使得学到的多粒度特征关注相似的目标区域。作者提出了使用随机的图像拼接构建多粒度的训练图像,在每一个step的训练时使用对应粒度的训练图像,即鼓励网络学习指定粒度的特征。
给定输入图像 d ∈ R 3 × W × H d \in R^{3 \times W \times H} d∈R3×W×H,将其分为 n × n n \times n n×n个patch,每一个patch大小为 3 × W n × H n 3 \times \frac{W}{n} \times \frac{H}{n} 3×nW×nH。W和H应该为n的整数倍。随后,这些patch随机打乱再拼接成新的训练图像 P ( d , n ) P(d,n) P(d,n)。因此,patch的粒度由超参数 n n n进行控制。
每一个stage的超参数 n n n的选择由两个因素决定:
- patch的大小应该小于当前stage的感受野的大小;
- patch的大小应该随着stage的感受野的增大而增大。
一般情况下,一个stage的感受野的大小应该为前一个stage的感受野的两倍,所以,这里设置第 l l l个stage的 n = 2 L − l + 1 n = 2^{L - l + 1} n=2L−l+1。
训练过程中,在每一个step,先根据该stage所在的顺序计算其对应的patch的大小,即 n = 2 L − l + 1 n = 2^{L - l + 1} n=2L−l+1,patch大小为 3 × W n × H n 3 \times \frac{W}{n} \times \frac{H}{n} 3×nW×nH。然后按照patch进行图像拼图,得到 P ( d , n ) P(d,n) P(d,n),将 P ( d , n ) P(d,n) P(d,n)和对应的类别信息 y y y组成pair进行模型训练。
拼图策略不能够保证小于patch的目标的完整,但对于模型训练来说,这不一定是坏事。因为在模型训练时,一般都会首先进行图像随机裁剪处理,因此每一轮迭代时使用的拼接图的切分位置和上一轮迭代时的切分位置是不同的。因此,这就迫使模型要更能够学习到指定粒度的更有判别力的特征。
2.4 推理
推理时可以只使用concat的预测结果,即:
也可以使用多个stage和concat预测结果的加权和,即:
4 实验效果
4.1 对比其他算法
4.2 消融研究
不使用拼图,即各个step都使用原图进行训练,不同stage数量时的分类结果:
stage越多,效果越好,但并非stage越大越好,如stage数量大于3之后,分类准确率开始下降。这是因为网络的浅层主要学习和类别不相干的特征,如果stage太大,让网络的前面层去学习和类别相关的特征反倒不利于整体的分类。这个实验说明了渐进式训练的重要性。
使用拼图,不同step使用不同粒度的训练图像,结果如下表所示:
表3相对于表2,在相同S值时,不同的n值可以看出,使用对应粒度的拼接图像进行训练是有好处的。但是如果n值太大,图像别切分的太碎,一个patch内包含的有用信息就太少了,对模型训练是不利的。
4.3 可视化
使用grad-cam可视化分类的关键区域,如下图所示:
可以看出作者的方法中不同的stage关注的特征的粒度不同,由小到大变化。相比于参考模型,作者的模型关注到了不同区域之间的关联关系,而参考模型则只关注到了图像的局部区域。
5 代码
jigsaw_generator即将原始图像等分为小碎块,然后再随机拼接成新的训练图像:
def jigsaw_generator(images, n):
l = []
for a in range(n):
for b in range(n):
l.append([a, b])
block_size = 448 // n
rounds = n ** 2
random.shuffle(l)
jigsaws = images.clone()
for i in range(rounds):
x, y = l[i]
temp = jigsaws[..., 0:block_size, 0:block_size].clone()
jigsaws[..., 0:block_size, 0:block_size] = jigsaws[..., x * block_size:(x + 1) * block_size,
y * block_size:(y + 1) * block_size].clone()
jigsaws[..., x * block_size:(x + 1) * block_size, y * block_size:(y + 1) * block_size] = temp
return jigsaws
渐进式训练:
# Step 1
optimizer.zero_grad()
inputs1 = jigsaw_generator(inputs, 8)
output_1, _, _, _ = netp(inputs1)
loss1 = CELoss(output_1, targets) * 1
loss1.backward()
optimizer.step()
# Step 2
optimizer.zero_grad()
inputs2 = jigsaw_generator(inputs, 4)
_, output_2, _, _ = netp(inputs2)
loss2 = CELoss(output_2, targets) * 1
loss2.backward()
optimizer.step()
# Step 3
optimizer.zero_grad()
inputs3 = jigsaw_generator(inputs, 2)
_, _, output_3, _ = netp(inputs3)
loss3 = CELoss(output_3, targets) * 1
loss3.backward()
optimizer.step()
# Step 4
optimizer.zero_grad()
_, _, _, output_concat = netp(inputs)
concat_loss = CELoss(output_concat, targets) * 2
concat_loss.backward()
optimizer.step()