PaddleDetection Python 端训练时进行类别加权

由于实验样本的类别分布不平衡,这里通过类别加权来提高 mAP。

在 PaddleDetection 动态图版本上,需要修改以下两个文件中的 get_loss:

mask_head.py

    def get_loss(self, mask_logits, mask_label, mask_target, mask_weight):
        """Compute the mask loss with a per-class positive weight.

        The ground-truth class of each RoI selects one entry from a fixed
        class-weight table. That entry is passed as ``pos_weight`` to the
        binary cross-entropy, so it scales the positive-pixel term of the
        loss for that RoI.

        Args:
            mask_logits (Tensor): mask head output, indexed below as
                (roi, class, height, width).
            mask_label (Tensor): integer class index of each RoI (1-D).
            mask_target (Tensor): ground-truth mask of each RoI; cast to
                float32 before use.
            mask_weight (Tensor): loss weight, expanded with two trailing
                axes so it broadcasts over the mask's spatial dims
                (presumably one value per RoI -- confirm against caller).

        Returns:
            Tensor: binary cross-entropy-with-logits loss, reduced by mean.
        """
        # Class weights, ordered like the dataset's category list
        # (11 classes). Edit this table to match your own dataset.
        class_weight_table = paddle.to_tensor(
            [1.0, 0.2, 0.2, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
            dtype='float32')

        # Look up every RoI's weight in one vectorized gather instead of a
        # Python loop over tensor elements, then add two trailing axes so the
        # per-RoI weight broadcasts against the (roi, height, width) mask.
        pos_weight = paddle.gather(class_weight_table, mask_label)
        pos_weight = pos_weight.unsqueeze([1, 2])

        # Build a one-hot selector over the class axis and use it to pick,
        # for each RoI, the logits of its ground-truth class only.
        mask_label = F.one_hot(mask_label, self.num_classes).unsqueeze([2, 3])
        mask_label = paddle.expand_as(mask_label, mask_logits)
        mask_label.stop_gradient = True
        mask_pred = paddle.gather_nd(mask_logits, paddle.nonzero(mask_label))
        shape = mask_logits.shape
        mask_pred = paddle.reshape(mask_pred, [shape[0], shape[2], shape[3]])
        mask_target = mask_target.cast('float32')
        mask_weight = mask_weight.unsqueeze([1, 2])

        loss_mask = F.binary_cross_entropy_with_logits(
            mask_pred,
            mask_target,
            weight=mask_weight,
            pos_weight=pos_weight,
            reduction="mean")
        return loss_mask

bbox_head.py


    def get_loss(self, scores, deltas, targets, rois, bbox_weight):
        """
        Compute class-weighted bbox classification and regression losses.

        Args:
            scores (Tensor): scores from bbox head outputs
            deltas (Tensor): deltas from bbox head outputs
            targets (list[List[Tensor]]): bbox targets containing tgt_labels,
                tgt_bboxes and tgt_gt_inds
            rois (List[Tensor]): RoIs generated in each batch
            bbox_weight: forwarded unchanged to ``bbox2delta`` when encoding
                the regression targets

        Returns:
            dict: ``{'loss_bbox_cls': Tensor, 'loss_bbox_reg': Tensor}``.
        """
        cls_name = 'loss_bbox_cls'
        reg_name = 'loss_bbox_reg'
        loss_bbox = {}

        # Per-class weights shared by the classification and regression
        # losses. 12 entries, indexed by label value.
        # NOTE(review): presumably 11 foreground classes plus background as
        # the last index -- confirm against self.num_classes.
        class_weight = paddle.to_tensor(
            [1.0, 0.2, 0.2, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
            dtype='float32')

        # TODO: better pass args
        tgt_labels, tgt_bboxes, tgt_gt_inds = targets

        # bbox cls
        tgt_labels = paddle.concat(tgt_labels) if len(
            tgt_labels) > 1 else tgt_labels[0]
        valid_inds = paddle.nonzero(tgt_labels >= 0).flatten()
        if valid_inds.shape[0] == 0:
            loss_bbox[cls_name] = paddle.zeros([1], dtype='float32')
        else:
            tgt_labels = tgt_labels.cast('int64')
            tgt_labels.stop_gradient = True
            loss_bbox[cls_name] = F.cross_entropy(
                input=scores,
                label=tgt_labels,
                weight=class_weight,
                reduction='mean')

        # bbox reg
        # A deltas width of 4 means one box is regressed per RoI regardless
        # of class; otherwise each class has its own group of 4 columns.
        cls_agnostic_bbox_reg = deltas.shape[1] == 4
        fg_inds = paddle.nonzero(
            paddle.logical_and(tgt_labels >= 0, tgt_labels <
                               self.num_classes)).flatten()

        if fg_inds.numel() == 0:
            loss_bbox[reg_name] = paddle.zeros([1], dtype='float32')
            return loss_bbox

        # Labels of the foreground RoIs: used to pick the per-class delta
        # columns and to look up each RoI's regression weight.
        fg_gt_classes = paddle.gather(tgt_labels, fg_inds)

        if cls_agnostic_bbox_reg:
            reg_delta = paddle.gather(deltas, fg_inds)
        else:
            # Build (row, col) index pairs that select, for every foreground
            # RoI, the 4 delta columns belonging to its ground-truth class.
            reg_row_inds = paddle.arange(fg_gt_classes.shape[0]).unsqueeze(1)
            reg_row_inds = paddle.tile(reg_row_inds, [1, 4]).reshape([-1, 1])

            reg_col_inds = 4 * fg_gt_classes.unsqueeze(1) + paddle.arange(4)
            reg_col_inds = reg_col_inds.reshape([-1, 1])
            reg_inds = paddle.concat([reg_row_inds, reg_col_inds], axis=1)

            reg_delta = paddle.gather(deltas, fg_inds)
            reg_delta = paddle.gather_nd(reg_delta, reg_inds).reshape([-1, 4])
        rois = paddle.concat(rois) if len(rois) > 1 else rois[0]
        tgt_bboxes = paddle.concat(tgt_bboxes) if len(
            tgt_bboxes) > 1 else tgt_bboxes[0]

        reg_target = bbox2delta(rois, tgt_bboxes, bbox_weight)
        reg_target = paddle.gather(reg_target, fg_inds)
        reg_target.stop_gradient = True

        if self.bbox_loss is not None:
            # NOTE: the class weights are not applied on this path; only the
            # default L1 branch below is class-weighted.
            reg_delta = self.bbox_transform(reg_delta)
            reg_target = self.bbox_transform(reg_target)

            loss_bbox_reg = self.bbox_loss(
                reg_delta, reg_target).sum() / tgt_labels.shape[0]
            loss_bbox_reg *= self.num_classes
        else:
            # One weight per foreground RoI, with a trailing axis added so it
            # broadcasts across the 4 box coordinates (not a transpose).
            box_weight = paddle.gather(class_weight, fg_gt_classes)
            box_weight = box_weight.unsqueeze(1)

            # Class-weighted L1 loss, normalized by the total number of
            # sampled RoIs rather than by the foreground count.
            loss_bbox_reg = paddle.abs(
                (reg_delta - reg_target) * box_weight).sum(
                ) / tgt_labels.shape[0]

        loss_bbox[reg_name] = loss_bbox_reg

        return loss_bbox

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值