Building the CenterNet Object Detection Model from Scratch
Preface
"What is truly good is simple, and also elegant" — no exaggeration to apply that to CenterNet. Drawing on the paper and the official source code, I extracted the essential object-detection parts, rebuilt the model in PyTorch with an emphasis on simplicity and readability, and added comments throughout.
This article uses the code as its entry point. For the underlying theory, see the paper explainer "Throw away anchors! The real CenterNet — Objects as Points".
The GitHub repository for this article and the official source code are listed in the references at the end.
Model Construction
Model structure: ResNet-18 + upsampling + 3 output headers (figure from the original paper)
The following six lines are added to the __init__() function of the original resnet.py. self.layer5 through self.layer7 perform upsampling; self.hm, self.wh, and self.reg are the model's three output headers: the per-class keypoint heatmap, the width/height regression, and the center-offset regression, respectively.
self.layer4 = self._make_layer(block, 512, layers[3], stride=2) # /32
# the line above is from the stock torchvision source
self.layer5 = self._make_deconv_layer(512, 256) # /16
self.layer6 = self._make_deconv_layer(256, 128) # /8
self.layer7 = self._make_deconv_layer(128, 64) # /4
self.hm = self._make_header_layer(64, num_classes) # heatmap
self.wh = self._make_header_layer(64, 2) # width and height
self.reg = self._make_header_layer(64, 2) # regress offset
The upsampling function is shown below. The official implementation uses Deformable Convolutional Networks for the convolutions in the upsampling path; for ease of running, plain convolutions are used here instead.
# add for upsample
def _make_deconv_layer(self, in_ch, out_ch):
    deconv = nn.Sequential(
        nn.Conv2d(in_ch, out_ch, 3, 1, 1),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(inplace=True),
        nn.ConvTranspose2d(out_ch, out_ch, 4, 2, 1),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(inplace=True)
    )
    return deconv
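As a quick sanity check (a standalone sketch reusing the same layer definitions as above), each upsampling stage changes the channel count while doubling the spatial resolution, since ConvTranspose2d with kernel 4, stride 2, padding 1 exactly doubles H and W:

```python
import torch
import torch.nn as nn

# standalone copy of the upsampling block above, for shape verification
def make_deconv_layer(in_ch, out_ch):
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, 3, 1, 1),            # 3x3 conv: changes channels, keeps H, W
        nn.BatchNorm2d(out_ch),
        nn.ReLU(inplace=True),
        nn.ConvTranspose2d(out_ch, out_ch, 4, 2, 1),  # kernel 4, stride 2, pad 1: doubles H, W
        nn.BatchNorm2d(out_ch),
        nn.ReLU(inplace=True),
    )

x = torch.randn(1, 512, 16, 16)       # e.g. layer4 output (1/32) for a 512x512 input
y = make_deconv_layer(512, 256)(x)
print(y.shape)  # torch.Size([1, 256, 32, 32])
```

Chaining three such stages takes the backbone's 1/32-resolution feature map back up to 1/4 resolution, which is where the headers operate.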
The function that builds each header is shown below.
# add for three headers
def _make_header_layer(self, in_ch, out_ch):
    header = nn.Sequential(
        nn.Conv2d(in_ch, in_ch, 3, 1, 1),
        nn.ReLU(inplace=True),
        nn.Conv2d(in_ch, out_ch, 1, 1)
    )
    return header
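Putting the pieces together, the sketch below shows the output shapes of the three headers applied to the upsampled feature map. It uses a standalone copy of the header builder rather than the modified resnet.py, and assumes num_classes=6 (the six SeaShips classes) and a 512x512 input, so the 1/4-resolution feature map is 128x128:

```python
import torch
import torch.nn as nn

# standalone copy of the header builder above
def make_header_layer(in_ch, out_ch):
    return nn.Sequential(
        nn.Conv2d(in_ch, in_ch, 3, 1, 1),   # 3x3 conv, keeps H, W
        nn.ReLU(inplace=True),
        nn.Conv2d(in_ch, out_ch, 1, 1),     # 1x1 conv projects to the header's channels
    )

num_classes = 6                       # SeaShips has six ship classes
feat = torch.randn(1, 64, 128, 128)   # layer7 output: 64 channels at 1/4 resolution
hm  = make_header_layer(64, num_classes)(feat)  # per-class keypoint heatmaps
wh  = make_header_layer(64, 2)(feat)            # width/height regression
reg = make_header_layer(64, 2)(feat)            # center-offset regression
print(hm.shape, wh.shape, reg.shape)
# torch.Size([1, 6, 128, 128]) torch.Size([1, 2, 128, 128]) torch.Size([1, 2, 128, 128])
```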
Dataset Construction
SeaShips dataset: 7,000 images, all at 1920x1080 resolution, covering six classes of ships (dataset link). First, here is a test image from the model after 72 epochs of training.
Data Initialization
The Dataset follows the standard torch format. In the __getitem__ function: list_bbox_cls is [(bbox1,cls1),(bbox2,cls2),...]; real_w and real_h are the original image's width and height; down_ratio is the downsampling factor (default 4); heatmap_size is the size of the heatmap the model finally outputs (default 128); hm, wh, and reg were introduced above; max_objs is the maximum number of objects one image may contain (there are few ships per image, so the default is 32); ind is the 1D index of each object's keypoint in the flattened 2D heatmap; reg_mask is the object mask array, 1 if a slot holds an object and 0 otherwise.
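The ind entry for each object is simply the row-major flattened index of its center cell on the heatmap grid, i.e. y * heatmap_size + x. A small sketch of that bookkeeping (heatmap_size=128 and max_objs=32 as in the text; the center coordinates are made up for illustration):

```python
import numpy as np

heatmap_size = 128
max_objs = 32

ind = np.zeros(max_objs, dtype=np.int64)
reg_mask = np.zeros(max_objs, dtype=np.uint8)

# suppose object 0 has its center at heatmap coordinates (x=40.3, y=25.7)
center = np.array([40.3, 25.7], dtype=np.float32)
center_int = center.astype(np.int32)                   # integer cell the keypoint falls in
ind[0] = center_int[1] * heatmap_size + center_int[0]  # row-major 1D index: y * W + x
reg_mask[0] = 1                                        # slot 0 now holds a real object

print(ind[0])  # 25 * 128 + 40 = 3240
```

The fractional part that is lost when snapping the center to an integer cell (0.3, 0.7 here) is exactly what the reg header learns to recover.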
class CTDataset(Dataset):
    def __init__(self, opt, data, transform=None):
        '''
        Dataset construction
        :param opt: configuration parameters
        :param data: [(img_path, [(bbox1,cls1),(bbox2,cls2),...]), ...]  bbox = (left, top, right, bottom)
        :param transform:
        '''
        self.images = data
        self.opt = opt
        self.transform = transform

    def __getitem__(self, index):
        img_path, list_bbox_cls = self.images[index]
        img = Image.open(img_path)
        real_w, real_h = img.size
        if self.transform: img = self.transform(img)
        heatmap_size = self.opt.input_size // self.opt.down_ratio
        # heatmap
        hm = np.zeros((self.opt.num_classes, heatmap_size, heatmap_size), dtype=np.float32)
        # width and height
        wh = np.zeros((self.opt.max_objs, 2), dtype=np.float32)
        # regression
        reg = np.zeros((self.opt.max_objs, 2), dtype=np.float32)
        # index in 1D heatmap (np.int is deprecated, use np.int64)
        ind = np.zeros((self.opt.max_objs), dtype=np.int64)
        # 1 = there is a target in the slot, 0 = there is not
        reg_mask = np.zeros((self.opt.max_objs), dtype=np.uint8)
        # ratio from original-image coordinates to heatmap coordinates
        w_ratio = self.opt.input_size / real_w / self.opt.down_ratio
        h_ratio = self.opt.input_size / real_h / self.opt.down_ratio
        for i, (bbox, cls) in enumerate(list_bbox_cls):
            # original bbox size -> heatmap bbox size
            bbox = bbox[0] * w_ratio, bbox[1] * h_ratio, bbox[2] * w_ratio, bbox[3] * h_ratio
            width, height = bbox[2] - bbox[0], bbox[3] - bbox[1]
            # center point (x, y)
            center = np.array([(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3])