Tensorflow2.0—Centernet网络原理及代码解析(三)- 损失函数的构建

Tensorflow2.0—Centernet网络原理及代码解析(三)- 损失函数的构建

Tensorflow2.0—Centernet网络原理及代码解析(一)- 特征提取网络Tensorflow2.0—Centernet网络原理及代码解析(二)- 数据生成中已经把Centernet网络基本的backbone和数据生成讲完了,还剩一个损失函数的构建~
在train.py中

model = centernet(input_shape, num_classes=num_classes, backbone=backbone, mode='train')

loss_ = Lambda(loss, name='centernet_loss')([y1, y2, y3, hm_input, wh_input, reg_input, reg_mask_input, index_input])

就是损失函数构建的函数~
先来看下这个函数的参数(前三个是通过预测网络得到的预测结果,后五个是通过标签进行编码之后的结果):

#   hm_pred:热力图的预测值       (batch_size, 128, 128, num_classes)
#   wh_pred:宽高的预测值         (batch_size, 128, 128, 2)
#   reg_pred:中心坐标偏移预测值  (batch_size, 128, 128, 2)
#   hm_true:热力图的真实值       (batch_size, 128, 128, num_classes)
#   wh_true:宽高的真实值         (batch_size, max_objects, 2)
#   reg_true:中心坐标偏移真实值  (batch_size, max_objects, 2)
#   reg_mask:真实值的mask        (batch_size, max_objects)
#   indices:真实值对应的坐标     (batch_size, max_objects)

1.热力图的损失函数计算

首先,先对热力图进行损失计算:

hm_loss = focal_loss(hm_pred, #预测热力图,shape=(2,128,128,20)
					 hm_true  #真实编码热力图,shape=(2,128,128,20)
					 )

先来看下热力图的损失函数计算原理:
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
代码实现:

	#   heatmap 为 1 的部分是正样本
    pos_mask = tf.cast(tf.equal(hm_true, 1), tf.float32)
    #   其余的为负样本
    neg_mask = tf.cast(tf.less(hm_true, 1), tf.float32)

先计算损失函数中(1-Yxyc)^4部分的常数值(对应于损失函数公式中otherwise公式的第一部分,默认β=4):

    neg_weights = tf.pow(1 - hm_true, 4) # 对应于(1-Yxyc)^4

然后分别计算正样本与负样本的损失函数:

	#  正样本的损失函数计算公式
    pos_loss = -tf.math.log(tf.clip_by_value(hm_pred, 1e-6, 1.)) * tf.pow(1 - hm_pred, 2) * pos_mask
    #  负样本的损失函数计算公式
    neg_loss = -tf.math.log(tf.clip_by_value(1 - hm_pred, 1e-6, 1.)) * tf.pow(hm_pred, 2) * neg_weights * neg_mask

得到的pos_loss和neg_loss的shape均为(2,128,128,20)。

	# 计算正样本的个数
    num_pos = tf.reduce_sum(pos_mask)
    #计算正样本的总热力图损失
    pos_loss = tf.reduce_sum(pos_loss)
    # 计算负样本的总热力图损失
    neg_loss = tf.reduce_sum(neg_loss)

然后,进行损失函数的归一化:

cls_loss = tf.cond(tf.greater(num_pos, 0), lambda: (pos_loss + neg_loss) / num_pos, lambda: neg_loss)

这一行代码相当于:

if num_pos == 0:
	loss = loss - neg_loss # 只有负样本
else:
	loss = loss - (pos_loss + neg_loss) / num_pos

解释:当没有正样本的时候,只取负样本损失的负数,当存在正样本的时候,将正负样本相加最后除正样本的个数即可。

2.宽高与中心偏移的损失函数计算

这两个放在一起讲的原因就是它们都是使用了reg_l1_loss函数。只不过在关于宽高的损失函数的前面加了一个损失系数0.1.

    wh_loss = 0.1 * reg_l1_loss(wh_pred, wh_true, indices, reg_mask)
    reg_loss = reg_l1_loss(reg_pred, reg_true, indices, reg_mask)

reg_l1_loss函数:

def reg_l1_loss(y_pred, y_true, indices, mask):
    #-------------------------------------------------------------------------#
    #   获得batch_size和num_classes
    #-------------------------------------------------------------------------#
    b, c = tf.shape(y_pred)[0], tf.shape(y_pred)[-1]
    k = tf.shape(indices)[1]

    y_pred = tf.reshape(y_pred, (b, -1, c))
    length = tf.shape(y_pred)[1]
    indices = tf.cast(indices, tf.int32)

    #-------------------------------------------------------------------------#
    #   利用序号取出预测结果中,和真实框相同的特征点的部分
    #-------------------------------------------------------------------------#
    batch_idx = tf.expand_dims(tf.range(0, b), 1)
    batch_idx = tf.tile(batch_idx, (1, k))
    full_indices = (tf.reshape(batch_idx, [-1]) * tf.cast(length, tf.int32) +
                    tf.reshape(indices, [-1]))

    y_pred = tf.gather(tf.reshape(y_pred, [-1,c]),full_indices)
    y_pred = tf.reshape(y_pred, [b, -1, c])

    mask = tf.tile(tf.expand_dims(mask, axis=-1), (1, 1, 2))
    #-------------------------------------------------------------------------#
    #   求取l1损失值
    #-------------------------------------------------------------------------#
    total_loss = tf.reduce_sum(tf.abs(y_true * mask - y_pred * mask))
    reg_loss = total_loss / (tf.reduce_sum(mask) + 1e-4)
    return reg_loss
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

进我的收藏吃灰吧~~

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值