SSD代码解读之四——loss

相关参数

loss的调用过程如下。

ssd_net.losses(logits, localisations,
                           b_gclasses, b_glocalisations, b_gscores,
                           match_threshold=FLAGS.match_threshold,
                           negative_ratio=FLAGS.negative_ratio,
                           alpha=FLAGS.loss_alpha,
                           label_smoothing=FLAGS.label_smoothing)

在考虑loss函数的实现过程之前,需要先把各个参数的含义和shape弄清楚。

  • logits:ssd网络的输出之一,是网络对每个default box的分类的预测值(未经过softmax计算概率)。值为一个list,每个feat_layer对应一个元素(共6个:[‘block4’, ‘block7’, ‘block8’, ‘block9’, ‘block10’, ‘block11’])。各个元素的shape为 (n, h, w, num_anchors_i, num_classes),其中n是batch_size,h和w是特征图的高和宽,num_anchors_i是该层特征图上每个点对应的default box的数量,值为4或6, num_classes是类别数量(标签类别数加1,背景为0)。
  • localisations:ssd网络的输出之一,为网络对每个default box坐标的偏移值的预测(cy, cx, h, w)。也是一个list,每个feat_layer对应一个元素。各个元素的shape为(n, h, w, num_anchors_i, 4)。
  • b_gclasses:ground truth值之一,为经过ssd_net.bboxes_encode编码处理的label。bboxes_encode得到的是一个list,每个feat_layer对应一个元素,每个元素的shape为(h, w, num_anchors_i)。经过tf.train.batch和tf_utils.reshape_list的处理后得到的是一个list,每个元素的shape为(batch_size, h, w, num_anchors_i)
  • b_glocalisations:ground truth值之一,与b_gclasses类似,是一个list,各个元素的shape为(batch_size, h, w, num_anchors_i, 4)
  • b_gscores:ground truth值之一,与b_gclasses类似, 是一个list,shape与b_gclasses相同,是把default box认定为b_gclasses对应位置的标签的得分,实际上就是该default box与ground truth中重合率最大的标签框的IOU。每个元素的shape为(batch_size, h, w, num_anchors_i)。
  • match_threshold:针对gscore的阈值,默认为0.5。若gscores > match_threshold,则认为该default box为正例
  • negative_ratio:负例与正例的比例,默认为3。做Hard negative mining使用。
  • alpha:用来均衡分类loss和定位loss的因数,默认为1
  • label_smoothing:未使用

实现过程

计算原理已经在前面文章分析过了,主要考虑代码实现

def ssd_losses(logits, localisations,
               gclasses, glocalisations, gscores,
               match_threshold=0.5,
               negative_ratio=3.,
               alpha=1.,
               label_smoothing=0.,
               device='/cpu:0',
               scope=None):
    with tf.name_scope(scope, 'ssd_losses'):
        lshape = tfe.get_shape(logits[0], 5)
        num_classes = lshape[-1] 
        batch_size = lshape[0] 

        # Flatten out all vectors!
        flogits = []
        fgclasses = []
        fgscores = []
        flocalisations = []
        fglocalisations = []
		#每个feat_layer的h、w和num_anchor_i是不一样的,这里需要按层把所有default box的数据合并到一起
        for i in range(len(logits)): #logits的len为6,每个feat_layer对应一个元素
        	#把一个feat_layer的所有default box并列到一起
            flogits.append(tf.reshape(logits[i], [-1, num_classes])) 
            fgclasses.append(tf.reshape(gclasses[i], [-1]))
            fgscores.append(tf.reshape(gscores[i], [-1]))
            flocalisations.append(tf.reshape(localisations[i], [-1, 4]))
            fglocalisations.append(tf.reshape(glocalisations[i], [-1, 4]))
        #把所有feat_layer的default box并列到一起
        logits = tf.concat(flogits, axis=0)  #(8732, num_classes),所有feat_layer总共8732个default box
        gclasses = tf.concat(fgclasses, axis=0)  #(8732,)
        gscores = tf.concat(fgscores, axis=0)  #(8732,)
        localisations = tf.concat(flocalisations, axis=0) #(8732, 4)
        glocalisations = tf.concat(fglocalisations, axis=0) #(8732, 4)
        dtype = logits.dtype

        # Compute positive matching mask...
        pmask = gscores > match_threshold #正例的mask
        fpmask = tf.cast(pmask, dtype)
        n_positives = tf.reduce_sum(fpmask) #正例的数量

        # Hard negative mining...
        no_classes = tf.cast(pmask, tf.int32) #背景处的pmask值为0,正好是背景的标签值
        predictions = slim.softmax(logits) #各个default box的预测值, (8732, num_classes)
        nmask = tf.logical_and(tf.logical_not(pmask), #gscores <= match_threshold
                               gscores > -0.5) #gscores肯定大于0,为什么这么写没搞懂。。
        fnmask = tf.cast(nmask, dtype)
        nvalues = tf.where(nmask, #背景处为prediction,其余为1
                           predictions[:, 0], #class 0: background,即背景类的预测概率
                           1. - fnmask) # nmask为0的地方fnmask也为0,因此非背景处nvalues值为1
        nvalues_flat = tf.reshape(nvalues, [-1]) #(8732 * num_classes,)
        # Number of negative entries to select.
        max_neg_entries = tf.cast(tf.reduce_sum(fnmask), tf.int32) #背景的实际数量
        n_neg = tf.cast(negative_ratio * n_positives, tf.int32) + batch_size #负例的数量为正例数量 × ratio + batch_size
        n_neg = tf.minimum(n_neg, max_neg_entries) #避免实际数量小于计算得到的数量

        val, idxes = tf.nn.top_k(-nvalues_flat, k=n_neg) #这里先取负,相当于取值最接近0的k个prediction(prediction越接近0,说明预测的越离谱,所谓的hard negative)
        max_hard_pred = -val[-1] #k个prediction值中最大的
        # Final negative mask.
        nmask = tf.logical_and(nmask, nvalues < max_hard_pred) #缩小nmask的范围,只取prediction最接近0的k个
        fnmask = tf.cast(nmask, dtype)

        # Add cross-entropy loss.
        with tf.name_scope('cross_entropy_pos'):
            loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits,
                                                                  labels=gclasses)
            loss = tf.div(tf.reduce_sum(loss * fpmask), batch_size, name='value')
            tf.losses.add_loss(loss)

        with tf.name_scope('cross_entropy_neg'):
            loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits,
                                                                  labels=no_classes)
            loss = tf.div(tf.reduce_sum(loss * fnmask), batch_size, name='value')
            tf.losses.add_loss(loss)

        # Add localization loss: smooth L1, L2, ...
        with tf.name_scope('localization'):
            # Weights Tensor: positive mask + random negative.
            weights = tf.expand_dims(alpha * fpmask, axis=-1)
            loss = custom_layers.abs_smooth(localisations - glocalisations)
            loss = tf.div(tf.reduce_sum(loss * weights), batch_size, name='value')
            tf.losses.add_loss(loss)
  • 1
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 4
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值