Tensorflow2.0—Centernet网络原理及代码解析(三)- 损失函数的构建
在Tensorflow2.0—Centernet网络原理及代码解析(一)- 特征提取网络和Tensorflow2.0—Centernet网络原理及代码解析(二)- 数据生成中已经把Centernet网络基本的backbone和数据生成讲完了,还剩一个损失函数的构建~
在train.py中
model = centernet(input_shape, num_classes=num_classes, backbone=backbone, mode='train')
中
loss_ = Lambda(loss, name='centernet_loss')([y1, y2, y3, hm_input, wh_input, reg_input, reg_mask_input, index_input])
就是损失函数构建的函数~
先来看下这个函数的参数(前三个是通过预测网络得到的预测结果,后五个是通过标签进行编码之后的结果):
# hm_pred:热力图的预测值 (batch_size, 128, 128, num_classes)
# wh_pred:宽高的预测值 (batch_size, 128, 128, 2)
# reg_pred:中心坐标偏移预测值 (batch_size, 128, 128, 2)
# hm_true:热力图的真实值 (batch_size, 128, 128, num_classes)
# wh_true:宽高的真实值 (batch_size, max_objects, 2)
# reg_true:中心坐标偏移真实值 (batch_size, max_objects, 2)
# reg_mask:真实值的mask (batch_size, max_objects)
# indices:真实值对应的坐标 (batch_size, max_objects)
1.热力图的损失函数计算
首先,先对热力图进行损失计算:
hm_loss = focal_loss(hm_pred, #预测热力图,shape=(2,128,128,20)
hm_true #真实编码热力图,shape=(2,128,128,20)
)
先来看下热力图的损失函数计算原理:
代码实现:
# heatmap 为 1 的部分是正样本
pos_mask = tf.cast(tf.equal(hm_true, 1), tf.float32)
# 其余的为负样本
neg_mask = tf.cast(tf.less(hm_true, 1), tf.float32)
先计算损失函数中(1-Yxyc)^4部分的常数值(对应于损失函数公式中otherwise公式的第一部分,默认β=4):
neg_weights = tf.pow(1 - hm_true, 4) # 对应于(1-Yxyc)^4
然后分别计算正样本与负样本的损失函数:
# 正样本的损失函数计算公式
pos_loss = -tf.math.log(tf.clip_by_value(hm_pred, 1e-6, 1.)) * tf.pow(1 - hm_pred, 2) * pos_mask
# 负样本的损失函数计算公式
neg_loss = -tf.math.log(tf.clip_by_value(1 - hm_pred, 1e-6, 1.)) * tf.pow(hm_pred, 2) * neg_weights * neg_mask
得到的pos_loss和neg_loss的shape均为(2,128,128,20)。
# 计算正样本的个数
num_pos = tf.reduce_sum(pos_mask)
#计算正样本的总热力图损失
pos_loss = tf.reduce_sum(pos_loss)
# 计算负样本的总热力图损失
neg_loss = tf.reduce_sum(neg_loss)
然后,进行损失函数的归一化:
cls_loss = tf.cond(tf.greater(num_pos, 0), lambda: (pos_loss + neg_loss) / num_pos, lambda: neg_loss)
这一行代码相当于:
if num_pos == 0:
loss = loss - neg_loss # 只有负样本
else:
loss = loss - (pos_loss + neg_loss) / num_pos
解释:当没有正样本的时候,只取负样本损失的负数,当存在正样本的时候,将正负样本相加最后除正样本的个数即可。
2.宽高与中心偏移的损失函数计算
这两个放在一起讲的原因就是它们都是使用了reg_l1_loss函数。只不过在关于宽高的损失函数的前面加了一个损失系数0.1.
wh_loss = 0.1 * reg_l1_loss(wh_pred, wh_true, indices, reg_mask)
reg_loss = reg_l1_loss(reg_pred, reg_true, indices, reg_mask)
reg_l1_loss函数:
def reg_l1_loss(y_pred, y_true, indices, mask):
#-------------------------------------------------------------------------#
# 获得batch_size和num_classes
#-------------------------------------------------------------------------#
b, c = tf.shape(y_pred)[0], tf.shape(y_pred)[-1]
k = tf.shape(indices)[1]
y_pred = tf.reshape(y_pred, (b, -1, c))
length = tf.shape(y_pred)[1]
indices = tf.cast(indices, tf.int32)
#-------------------------------------------------------------------------#
# 利用序号取出预测结果中,和真实框相同的特征点的部分
#-------------------------------------------------------------------------#
batch_idx = tf.expand_dims(tf.range(0, b), 1)
batch_idx = tf.tile(batch_idx, (1, k))
full_indices = (tf.reshape(batch_idx, [-1]) * tf.cast(length, tf.int32) +
tf.reshape(indices, [-1]))
y_pred = tf.gather(tf.reshape(y_pred, [-1,c]),full_indices)
y_pred = tf.reshape(y_pred, [b, -1, c])
mask = tf.tile(tf.expand_dims(mask, axis=-1), (1, 1, 2))
#-------------------------------------------------------------------------#
# 求取l1损失值
#-------------------------------------------------------------------------#
total_loss = tf.reduce_sum(tf.abs(y_true * mask - y_pred * mask))
reg_loss = total_loss / (tf.reduce_sum(mask) + 1e-4)
return reg_loss