Tensorflow2.0—Centernet网络原理及代码解析(一)- 特征提取网络
目前,目标检测已经被anchor base类型的网络所承包了,anchor free类型的算法已经越来越少了,最近看大佬的blogTensorflow2 搭建自己的Centernet目标检测平台,就想自己尝试学习下anchor free的内容,先把代码过一遍,然后再总结下原理。
这篇blog,就先把它的特征提取网络看一遍吧~~
在centernet中,backbone的选择性比较多,一般常见的有Hourglass Network、DLANet或者Resnet,在代码中有两个选择:Hourglass,Resnet。
# 获取centernet模型
model = centernet(input_shape, num_classes=num_classes, backbone=backbone, mode='train')
在train.py中这一行代码是进行centernet的特征提取网络的创建。
创建一系列INPUT,每个的含义为:
# hm_true:热力图的真实值 (batch_size, 128, 128, num_classes)
# wh_true:宽高的真实值 (batch_size, max_objects, 2)
# reg_true:中心坐标偏移真实值 (batch_size, max_objects, 2)
# reg_mask:真实值的mask (batch_size, max_objects)
# indices:真实值对应的坐标 (batch_size, max_objects)
一、Resnet
https://blog.csdn.net/weixin_44791964/article/details/113682561?spm=1001.2014.3001.5501
C5 = ResNet50(image_input)
最终返回的上图中红色部分,也就是backbone部分,返回的C5的shape为(None,16,16,2048)。然后将C5喂进center head中。
x = Dropout(rate=0.5)(x)
#-------------------------------#
# 解码器
#-------------------------------#
num_filters = 256
# 16, 16, 2048 -> 32, 32, 256 -> 64, 64, 128 -> 128, 128, 64
for i in range(3):
# 进行上采样
x = Conv2DTranspose(num_filters // pow(2, i), (4, 4), strides=2, use_bias=False, padding='same',
kernel_initializer='he_normal',
kernel_regularizer=l2(5e-4))(x)
x = BatchNormalization()(x)
x = Activation('relu')(x)
先对C5进行随机Dropout,目的是避免模型过拟合,然后对输入进行三次连续的上采样,对应于上图中绿色部分。
然后,针对进行上采样之后的x分别对其进行维度上面的调整,最终得到的y1 = (None,128,128,20),
y2 = (None,128,128,2),y3 = (None,128,128,2)。y1-3所包含的意义为:
y1:热力图的预测值 (batch_size, 128, 128, num_classes) y2:宽高的预测值 (batch_size, 128, 128, 2) y3:中心坐标偏移预测值 (batch_size, 128, 128, 2)
二、hourglass
有时间再写吧~~