论文链接:NAS-FPN: Learning Scalable Feature Pyramid Architecture for Object Detection
跨尺度连接,允许模型将高层具有丰富语义信息等特征和底层具有高分辨率的特征结合起来
本文基于retinanet,使用强化学习的方法来采样不同的融合策略,搜索出最好的fpn
搜索空间结构
强化学习
训练一个控制器,在给定的搜索空间中使用增强学习的方法选择最好的模型结构 。控制器利用 搜索空间中子模型的准确率最为奖励信息来更新参数。
基本空间
-
搜索的基本单位cell (NASnet 也是提出两个cell)
-
一个FPN包含N个不同的merging cell,每个cell来对任何两个输入特征层进行操作,并输出操作后的新特征。
-
一个merging cell中 所有的特征层都有相同的数量的filters
-
使用backbone的输出作为输入{C3,C4,C5,C6,C7},stride of {8,16,32,64,128} 。C6,C7通过对C5应用stride 2 和 stride 4 的最大池化获得
-
金字符网络 输出{P3,P4,P5,P6,P7}
merging cell
controller RNN 来构建merging cell 。
controller RNN选择两个候选的特征层和一个二值操作,来通过操作结合两个特征生成新的特征层。这些特征层可能有不同的分辨率。
每个merging cell 有4 个预测步骤:
1、从候选特征里面选择一个特征层hi
2、无重复的选择另一个特征层hj
3、选择输出特征的分辨率
4、选择二值操作,结合hi和hj,生成输出特征层
新生成的特征层放入候选list中
-
设计了两个二值操作。sum 和global pooling
-
在二值操作之前,输入层需要通过最近邻插值或者最大池化法来到达和输出层一样的分辨率。merged feature layers 后面跟着一个relu,一个3x3的卷积层和一个BN层(R-C-B)
-
为了减少在搜索网络结构中的计算量,我们在第三步的时候 避免选择stride 8的特征。输出{P3, P4, P5, P6, P7}
-
每个输出层通过重复1、2、4步来生成
merging cell 代码
class SumCell(MergingCell):
def _binary_op(self, x1, x2):
return x1 + x2
class GPCell(MergingCell):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
def _binary_op(self, x1, x2):
x2_att = self.global_pool(x2).sigmoid()
return x2 + x2_att * x1
细节
-
基于retinanet框架,TPUs 64images in a batch,multi scale 训练,随机尺寸比例[0.8,1.2]的输出图片尺寸,BN层应用在所有的卷积层之后
-
设置focal loss 超参 α = 0.25 α = 0.25 α=0.25和 γ = 1.5 γ = 1.5 γ=1.5
-
weight decay of 0.0001
momentum of 0.9.
50个epoch -
初始学习率0.08 ,稳定在前30个epochs,30-40epochs 衰减率0.1
-
在COCO2017 数据集上进行训练
-
Proxy task。为了加速RNN controller,需要Proxy task。Proxy task 需要很短的训练时间,并且与真实任务相关。在训练期间,Proxy task 能确定一个好的FPN结构。训练Proxy task 仅需要10个epochs ,而不是训练retinanet的50个epochs为了进一步加速Proxy task的训练速度,使用更小的resnet10 ,图片输入512*512。Proxy task重复3次 pyramid networks
学习率 0.08 前8个epoch,之后衰减率0.1,从cocotrain 2017中随机选择7392张图片 设置为验证集,来获得奖励信息。 -
Controller 是RNN,采用PPO优化算法。Controller 采样多个不同网络结构的子网络。这些网络结构被训练于Proxy task,使用 a pool of works
Result
实验包含了100个 TPUs
验证集的AP 被用于reward信息来更新Controller
叠加普通的FPN,性能没有改变。但是叠加NAS-FPN性能会提升
搜索到的结果
结构代码
来自mmdetection
# add NAS FPN connections
self.fpn_stages = nn.ModuleList()
for _ in range(self.stack_times):
stage = nn.ModuleDict()
# gp(p6, p4) -> p4_1
stage['gp_64_4'] = GPCell(out_channels, norm_cfg=norm_cfg)
# sum