Tensorflow2.0---SSD网络原理及代码解析(三)- 特征提取网络

40 篇文章 2 订阅
7 篇文章 2 订阅

Tensorflow2.0—SSD网络原理及代码解析(三)- 特征提取网络

model = SSD300(input_shape, NUM_CLASSES, anchors_size)

这行代码进行SSD特征提取网络的构建。一起来看看内部代码是如何实现的~
首先,先进行VGG16网络的搭建。
在这里插入图片描述
上述就是VGG16网络,用一个dict按照name进行保存。然后return回SSD特征提取网络代码中。接下来,就是对特定的网络层进行处理了~

# 对conv4_3的通道进行l2标准化处理 
    # 38,38,512
    net['conv4_3_norm'] = Normalize(20, name='conv4_3_norm')(net['conv4_3'])
    num_priors = 4
    # 预测框的处理
    # num_priors表示每个网格点先验框的数量,4是x,y,h,w的调整
    net['conv4_3_norm_mbox_loc'] = Conv2D(num_priors * 4, kernel_size=(3,3), padding='same', name='conv4_3_norm_mbox_loc')(net['conv4_3_norm'])
    net['conv4_3_norm_mbox_loc_flat'] = Flatten(name='conv4_3_norm_mbox_loc_flat')(net['conv4_3_norm_mbox_loc'])
    # num_priors表示每个网格点先验框的数量,num_classes是所分的类
    net['conv4_3_norm_mbox_conf'] = Conv2D(num_priors * num_classes, kernel_size=(3,3), padding='same',name='conv4_3_norm_mbox_conf')(net['conv4_3_norm'])
    net['conv4_3_norm_mbox_conf_flat'] = Flatten(name='conv4_3_norm_mbox_conf_flat')(net['conv4_3_norm_mbox_conf'])

    priorbox = PriorBox(input_shape, anchors_size[0], max_size=anchors_size[1], aspect_ratios=[2],
                        variances=[0.1, 0.1, 0.2, 0.2],
                        name='conv4_3_norm_mbox_priorbox')
    net['conv4_3_norm_mbox_priorbox'] = priorbox(net['conv4_3_norm'])

我就以第一个有效特征层进行解释,这是VGG16中conv4_3层,shape为(None,38,38,512)。由于第一层比较浅,所以对它做一个L2标准化操作,conv4_3_norm的shape不变还是(None,38,38,512)。然后进行预测框的生成,在代码实现中体现为feature map的channel的改变。conv4_3_norm_mbox_loc的shape为(None,38,38,16),最后一维的16是表示每个锚点生成4个box,每个box有xywh信息。conv4_3_norm_mbox_conf的shape为(None,38,38,84),最后一维84是表示每个锚点生成4个box,每个box包含21个类别概率。然后都将二者进行flatten操作。
在这里插入图片描述
预测框生成之后,就是要对锚点框进行生成操作~

priorbox = PriorBox(input_shape, anchors_size[0], max_size=anchors_size[1], aspect_ratios=[2],
                        variances=[0.1, 0.1, 0.2, 0.2],
                        name='conv4_3_norm_mbox_priorbox')
net['conv4_3_norm_mbox_priorbox'] = priorbox(net['conv4_3_norm'])

其实这个步骤与我写的Tensorflow2.0—SSD网络原理及代码解析(二)-锚点框的生成代码几乎一样,其实就是进行锚点框的生成。
在这里插入图片描述
结果就是(None,5776,8),表示生成了5776个anchor box,每个box前四个是左上角右下角坐标,后四个是variances[0.1, 0.1, 0.2, 0.2]。同理:
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
接着,进行anchor box维度上的拼接:
在这里插入图片描述
最后,进行reshape并进行concat:
在这里插入图片描述
这里。net[‘predictions’]就包含着预测框的loc和conf信息,和锚点框的信息。shape为(None,8732,33),8732表示一共有8732个anchor box。

  • 1
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

进我的收藏吃灰吧~~

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值