CenterNet 笔记注释

参考论文

简介

关于CenterNet 的介绍网上的文章很多,这篇论文和代码前前后后看了好几遍,但是始终不得其精髓,究其原因我觉得是我只是单纯的知道这个网络的流程,内部的具体操作与实现不清晰,导致总是稀里糊涂。今天再看一次代码,终于有了更深刻的理解,这里对我的理解做一个简单记录。

heat map 的生成

# radius:当前目标的高斯半经,center:当前目标的中心, heatmap: 初始化的时候都为0
def draw_gaussian(heatmap, center, radius, k=1):
    diameter = 2 * radius + 1
    # 生成heat map
    gaussian = gaussian2D((diameter, diameter), sigma=diameter / 6)
    x, y = int(center[0]), int(center[1])
    height, width = heatmap.shape[0:2]
    
    # 正常情况下,left, right = radius, radius+1,top, bottom = radius, radius+1
    # 只有当 radius> x 和 radius>y时才有一点变化,比如在边界处,主要是为了保证不越界
    left, right = min(x, radius), min(width - x, radius + 1)
    top, bottom = min(y, radius), min(height - y, radius + 1)
    
    # 在输出的特征图上找到当前目标以x,y为中心所在的位置
    masked_heatmap = heatmap[y - top:y + bottom, x - left:x + right]
    
    # 这里其实就是把gaussian直接给masked_gaussian。
    # [radius - top:radius + bottom, radius - left:radius + right]里的内容就是为了保证不会越界
    # masked_heatmap.shape = masked_gaussian.shape
    masked_gaussian = gaussian[radius - top:radius + bottom, radius - left:radius + right]
    if min(masked_gaussian.shape) > 0 and min(masked_heatmap.shape) > 0:  # TODO debug
        np.maximum(masked_heatmap, masked_gaussian * k, out=masked_heatmap)
        # 输出值为masked_heatmap所在的heatmap区域,就是把高斯区域值放到目标所在的位置

    return heatmap

gaussian2D 的作用就是 生成heat map,具体的生成方法看下面的例子

import numpy as np
def gaussian2D(shape, sigma=1):
    m, n = [(ss - 1.) / 2. for ss in shape]
    y, x = np.ogrid[-m:m + 1, -n:n + 1]
    print(y,x)
    print((x * x + y * y))
    h = np.exp(-(x * x + y * y) / (2 * sigma * sigma))
    h[h < np.finfo(h.dtype).eps * h.max()] = 0
    return h
    
h = gaussian2D([5,5])
print(h)

输出:
[[-2.]
 [-1.]
 [ 0.]
 [ 1.]
 [ 2.]] [[-2. -1.  0.  1.  2.]]
[[8. 5. 4. 5. 8.]
 [5. 2. 1. 2. 5.]
 [4. 1. 0. 1. 4.]
 [5. 2. 1. 2. 5.]
 [8. 5. 4. 5. 8.]]
[[0.01831564 0.082085   0.13533528 0.082085   0.01831564]
 [0.082085   0.36787944 0.60653066 0.36787944 0.082085  ]
 [0.13533528 0.60653066 1.         0.60653066 0.13533528]
 [0.082085   0.36787944 0.60653066 0.36787944 0.082085  ]
 [0.01831564 0.082085   0.13533528 0.082085   0.01831564]]

这里生产的heat map中只有该点值为1它才是关键点,也就是正样本,其他的点都是负样本,就算这个值是接近1。

Generator 生成训练样本时都返回什么玩意

class Generator(object):
    def __init__(self,batch_size,train_lines,val_lines,input_size,output_size,num_classes,max_objects=100):
        self.batch_size = batch_size
        self.train_lines = train_lines
        self.val_lines = val_lines
        self.input_size = input_size
        self.output_size = output_size
        self.num_classes = num_classes
        self.max_objects = max_objects
	def generate(self, train=True):
        while True:
            if train:
                # 打乱
                shuffle(self.train_lines)
                lines = self.train_lines
            else:
                shuffle(self.val_lines)
                lines = self.val_lines
                
            batch_images = np.zeros((self.batch_size, self.input_size[0], self.input_size[1], self.input_size[2]), dtype=np.float32)
            batch_hms = np.zeros((self.batch_size, self.output_size[0], self.output_size[1], self.num_classes), dtype=np.float32)
            batch_whs = np.zeros((self.batch_size, self.max_objects, 2), dtype=np.float32)
            batch_regs = np.zeros((self.batch_size, self.max_objects, 2), dtype=np.float32)
            batch_reg_masks = np.zeros((self.batch_size, self.max_objects), dtype=np.float32)
            batch_indices = np.zeros((self.batch_size, self.max_objects), dtype=np.float32)
            
            b = 0
            for annotation_line in lines:  
                img,y=self.get_random_data(annotation_line,self.input_size[0:2])

                if len(y)!=0:
                    boxes = np.array(y[:,:4],dtype=np.float32)
                    boxes[:,0] = boxes[:,0]/self.input_size[1]*self.output_size[1]
                    boxes[:,1] = boxes[:,1]/self.input_size[0]*self.output_size[0]
                    boxes[:,2] = boxes[:,2]/self.input_size[1]*self.output_size[1]
                    boxes[:,3] = boxes[:,3]/self.input_size[0]*self.output_size[0]

                for i in range(len(y)):
                    bbox = boxes[i].copy()
                    bbox = np.array(bbox)
                    bbox[[0, 2]] = np.clip(bbox[[0, 2]], 0, self.output_size[1] - 1)
                    bbox[[1, 3]] = np.clip(bbox[[1, 3]], 0, self.output_size[0] - 1)
                    cls_id = int(y[i,-1])
                    
                    h, w = bbox[3] - bbox[1], bbox[2] - bbox[0]
                    if h > 0 and w > 0:
                        ct = np.array([(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2], dtype=np.float32)
                        # 中心坐标取整
                        ct_int = ct.astype(np.int32)    
                        # 针对每一个目标都会生成一个高斯半经
                        radius = gaussian_radius((math.ceil(h), math.ceil(w)))
                        radius = max(0, int(radius))
                        # 获得热力图
                        batch_hms[b, :, :, cls_id] = draw_gaussian(batch_hms[b, :, :, cls_id], ct_int, radius)
                        # 第i个目标的wh
                        batch_whs[b, i] = 1. * w, 1. * h
                        # 第i个目标的中心偏移量
                        batch_regs[b, i] = ct - ct_int
                        # 第i个目标的mask设置为1,用于排除多余的0,假定一张图中共有100个目标
                        batch_reg_masks[b, i] = 1
                        # 表示第ct_int[1] 即(h)行的第ct_int[0] 即(w)个。把当前二维features拉成一维形式,当前中心点在一维数组中的位置。
                        batch_indices[b, i] = ct_int[1] * self.output_size[0] + ct_int[0]

                # 将RGB转化成BGR,归一化
                img = np.array(img,dtype = np.float32)[:,:,::-1]
                batch_images[b] = preprocess_image(img)
                b = b + 1
                if b == self.batch_size:
                    b = 0
                    # 返回的内容:图片,hm图,wh,偏移量,目标mask,以及在一维数组中的位置
                    yield [batch_images, batch_hms, batch_whs, batch_regs, batch_reg_masks, batch_indices], np.zeros((self.batch_size,))

                    batch_images = np.zeros((self.batch_size, self.input_size[0], self.input_size[1], 3), dtype=np.float32)

                    batch_hms = np.zeros((self.batch_size, self.output_size[0], self.output_size[1], self.num_classes),
                                        dtype=np.float32)
                    batch_whs = np.zeros((self.batch_size, self.max_objects, 2), dtype=np.float32)
                    batch_regs = np.zeros((self.batch_size, self.max_objects, 2), dtype=np.float32)
                    batch_reg_masks = np.zeros((self.batch_size, self.max_objects), dtype=np.float32)
                    batch_indices = np.zeros((self.batch_size, self.max_objects), dtype=np.float32)

损失计算

def loss(args):
    #-----------------------------------------------------------------------------------------------------------------#
    # hm_pred:热力图的预测值       (self.batch_size, self.output_size[0], self.output_size[1], self.num_classes)
    # wh_pred:宽高的预测值         (self.batch_size, self.output_size[0], self.output_size[1], 2)
    # reg_pred:中心坐标偏移预测值  (self.batch_size, self.output_size[0], self.output_size[1], 2)
    # hm_true:热力图的真实值       (self.batch_size, self.output_size[0], self.output_size[1], self.num_classes)
    # wh_true:宽高的真实值         (self.batch_size, self.max_objects, 2)
    # reg_true:中心坐标偏移真实值  (self.batch_size, self.max_objects, 2)
    # reg_mask:真实值的mask        (self.batch_size, self.max_objects)
    # indices:真实值对应的坐标     (self.batch_size, self.max_objects)
    #-----------------------------------------------------------------------------------------------------------------#
    hm_pred, wh_pred, reg_pred, hm_true, wh_true, reg_true, reg_mask, indices = args
    hm_loss = focal_loss(hm_pred, hm_true)
    wh_loss = 0.1 * reg_l1_loss(wh_pred, wh_true, indices, reg_mask)
    reg_loss = reg_l1_loss(reg_pred, reg_true, indices, reg_mask)
    total_loss = hm_loss + wh_loss + reg_loss
    # total_loss = tf.Print(total_loss,[hm_loss,wh_loss,reg_loss])
    return total_loss

def focal_loss(hm_pred, hm_true):
    # 找到正样本
    pos_mask = tf.cast(tf.equal(hm_true, 1), tf.float32)
    # 小于1的都是负样本
    neg_mask = tf.cast(tf.less(hm_true, 1), tf.float32)
    neg_weights = tf.pow(1 - hm_true, 4)
    
    pos_loss = -tf.math.log(tf.clip_by_value(hm_pred, 1e-6, 1.)) * tf.pow(1 - hm_pred, 2) * pos_mask
    neg_loss = -tf.math.log(tf.clip_by_value(1 - hm_pred, 1e-6, 1.)) * tf.pow(hm_pred, 2) * neg_weights * neg_mask

    num_pos = tf.reduce_sum(pos_mask)
    pos_loss = tf.reduce_sum(pos_loss)
    neg_loss = tf.reduce_sum(neg_loss)

    # tf.cond 就是if else 
    # tf.greater 比较两个值的大小,如果有正样本则计算正负两个的损失,没有则只计算负样本的损失
    cls_loss = tf.cond(tf.greater(num_pos, 0), lambda: (pos_loss + neg_loss) / num_pos, lambda: neg_loss)
    return cls_loss
    
# 计算wh和偏移量的公式与代码是一样的,采用L1损失函数,只是多了一个系数
def reg_l1_loss(y_pred, y_true, indices, mask):
    b, c = tf.shape(y_pred)[0], tf.shape(y_pred)[-1]
    # k : 最大预测100个目标
    k = tf.shape(indices)[1]
    # y_pred.shape = (self.batch_size, self.output_size[0], self.output_size[1], 2)
    # label 的 shape = (self.batch_size, self.max_objects, 2),所以先要reshape
    y_pred = tf.reshape(y_pred, (b, -1, c))
    # 预测出的目标个数
    length = tf.shape(y_pred)[1]
    # indices.shape = (b,100)
    indices = tf.cast(indices, tf.int32)
    
    batch_idx = tf.expand_dims(tf.range(0, b), 1)
    batch_idx = tf.tile(batch_idx, (1, k))
    # batch_idx.shape = (batch,100)

    ##  找到其在1维上的索引
    # batch_idx = 
    # [[0 0 0 ... 0 0 0]
    # [ 1 1 1 ... 1 1 1]
    # [ 2 2 2 ... 2 2 2]
    # [ 3 3 3 ... 3 3 3]
    # ...
    # [ b-1 b-1 b-1 ... b-1 b-1 b-1]]
    # 总共batch个图片,每个图片预测100目标,每个图片预测的总目标数是 length, 
    # y_pred 共预测 b*length个点。
    # 这里实际上每张图取了前100个点的位置,后面的都丢弃
    full_indices = (tf.reshape(batch_idx, [-1]) *tf.cast(length,tf.int32) +
                    tf.reshape(indices, [-1]))
    
    # 取出对应的预测值
    y_pred = tf.gather(tf.reshape(y_pred, [-1,c]),full_indices)
    y_pred = tf.reshape(y_pred, [b, -1, c])
    # y_pred.shape = [b, 100, 2]

    # mask.shape = (b,100,2)
    mask = tf.tile(tf.expand_dims(mask, axis=-1), (1, 1, 2))
    # 求取l1损失值, y_true.shape =  (b,100,2)
    total_loss = tf.reduce_sum(tf.abs(y_true * mask - y_pred * mask))
    reg_loss = total_loss / (tf.reduce_sum(mask) + 1e-4)
    return reg_loss

这里计算损失的代码与公示完全对应:
L = − 1 N ∑ x y c { ( 1 − Y ^ x y c ) α log ⁡ ( Y ^ x y c ) if  Y x y c = 1 ( 1 − Y x y c ) β ( Y ^ x y c ) α log ⁡ ( 1 − Y ^ x y c ) otherwise  L=\frac{-1}{N}\sum_{xyc} \left\{ \begin{aligned} (1- \hat{Y}_{xyc})^\alpha \log(\hat{Y}_{xyc}) &&\text{if } Y_{xyc} =1\\ (1-Y_{xyc})^\beta (\hat{Y}_{xyc})^\alpha \log(1-\hat{Y}_{xyc}) &&\text{otherwise } \end{aligned}\right. L=N1xyc{(1Y^xyc)αlog(Y^xyc)(1Yxyc)β(Y^xyc)αlog(1Y^xyc)if Yxyc=1otherwise 
其中, α \alpha α β \beta β 是Focal Loss 的超参数, N N N 是图像 I I I 的关键点数量,就是一副图片中目标的个数),在这篇论文 α = 2 \alpha = 2 α=2 β = 4 \beta = 4 β=4 ,这个损失函数是 Focal Loss 的修改版,适用于 CenterNet。

公示说明 详细分析见这里
1. 对于正样本 (只有高斯生成的hm图点为1才是正关键点)

对于 easy example 的中心点,适当减少其训练比重也就是 loss 值,当 Y x y c = 1 Y_{xyc} =1 Yxyc=1的时候,
( 1 − Y ^ x y c ) α (1- \hat{Y}_{xyc})^\alpha (1Y^xyc)α 就充当了矫正的作用, 假如 Y ^ x y c \hat{Y}_{xyc} Y^xyc 接近 1 的话,说明这个是一个比较容易检测出来的点,那么 ( 1 − Y ^ x y c ) α (1- \hat{Y}_{xyc})^\alpha (1Y^xyc)α 就相应比较低了。而当 Y ^ x y c \hat{Y}_{xyc} Y^xyc 接近0 的时候,说明这个中心点还没有学习到,所以要加大其训练的比重,因此 ( 1 − Y ^ x y c ) α (1- \hat{Y}_{xyc})^\alpha (1Y^xyc)α 就会很大。

2. 对于非正样本 (hm图中不是1的点都是负样本)

这里对实际中心点的其他近邻点的训练比重(loss)也进行了调整,当otherwise的时候 ( Y ^ x y c ) α (\hat{Y}_{xyc})^\alpha (Y^xyc)α 的预测值应为0, 如果不为0的且越来越接近1的话, ( Y ^ x y c ) α (\hat{Y}_{xyc})^\alpha (Y^xyc)α 的值就会变大从而使这个损失的训练比重也加大;而 ( 1 − Y x y c ) β (1-Y_{xyc})^\beta (1Yxyc)β 则对中心点周围的,和中心点靠得越近的点也做出了调整(因为与实际中心点靠的越近的点可能会影响干扰到实际中心点,造成误检测),因为 Y x y c Y_{xyc} Yxyc 是一个高斯核生成的中心点,在中心点 Y x y c = 1 Y_{xyc} =1 Yxyc=1, 但是在中心点周围扩散 Y x y c Y_{xyc} Yxyc 会由1慢慢变小但是并不是直接为0,因此 ( 1 − Y x y c ) β (1-Y_{xyc})^\beta (1Yxyc)β 与中心点距离越近, Y x y c Y_{xyc} Yxyc 接近1,这个值越小,相反则越大。

总结

这样看过论文后,再对着代码看一篇,这个网络的核心动心就非常清楚了。ok,暂且就这样。

  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值