yolov3选取正负样本

  • 负责预测目标网格中与ground truth的IOU最大的anchor为正样本(记住这里没有阈值的事情,否则会绕晕)
  • 剩下的anchor中,与全部ground truth的IOU都小于阈值的anchor为负样本
  • 其他是忽略样本
  • 代码未完待续
  • 获取正样本代码,参考这里
def calculate_iou(_box_a, _box_b):
		b1_x1, b1_x2 = _box_a[:, 0] - _box_a[:, 2] / 2, _box_a[:, 0] + _box_a[:, 2] / 2
        b1_y1, b1_y2 = _box_a[:, 1] - _box_a[:, 3] / 2, _box_a[:, 1] + _box_a[:, 3] / 2
        b2_x1, b2_x2 = _box_b[:, 0] - _box_b[:, 2] / 2, _box_b[:, 0] + _box_b[:, 2] / 2
        b2_y1, b2_y2 = _box_b[:, 1] - _box_b[:, 3] / 2, _box_b[:, 1] + _box_b[:, 3] / 2
        box_a = torch.zeros_like(_box_a)
        box_b = torch.zeros_like(_box_b)
        box_a[:, 0], box_a[:, 1], box_a[:, 2], box_a[:, 3] = b1_x1, b1_y1, b1_x2, b1_y2
        box_b[:, 0], box_b[:, 1], box_b[:, 2], box_b[:, 3] = b2_x1, b2_y1, b2_x2, b2_y2
        A = box_a.size(0)
        B = box_b.size(0)
        # intersection
        # expand to A*B*2 and compare
        max_xy  = torch.min(box_a[:, 2:].unsqueeze(1).expand(A, B, 2), box_b[:, 2:].unsqueeze(0).expand(A, B, 2))
        min_xy  = torch.max(box_a[:, :2].unsqueeze(1).expand(A, B, 2), box_b[:, :2].unsqueeze(0).expand(A, B, 2))
        # minus and set 0 if result less than 0
        inter   = torch.clamp((max_xy - min_xy), min=0)
        # size:A*B
        inter   = inter[:, :, 0] * inter[:, :, 1]
        area_a = ((box_a[:, 2]-box_a[:, 0]) * (box_a[:, 3]-box_a[:, 1])).unsqueeze(1).expand_as(inter) 
        area_b = ((box_b[:, 2]-box_b[:, 0]) * (box_b[:, 3]-box_b[:, 1])).unsqueeze(0).expand_as(inter)
        union = area_a + area_b - inter
        return inter / union
'''
targets是标签列表,长度是batch_size,元素的shape是(真实框个数*5)
anchors是[[116,90],[156,198],[373,326]]或[[30,61],[62,45],[59,119]]或[[10,13],[16,30],[33,23]]
in_h, in_w是13,13或26,26或52,52
num_classes是类别数,voc是20,COCO是80
'''
def get_target(targets, anchors, in_h, in_w, num_classes):
    bs=len(targets)
    positive=torch.zeros(bs,len(anchors),in_h, in_w, 5+num_classes,requires_grad = False)
    negtive=torch.ones(bs,len(anchors),in_h, in_w, requires_grad = False)
    for b in range(bs):
        batch_target = torch.zeros_like(targets[b])
        # 计算该特征图上标签的值
        batch_target[:, [0,2]] = targets[b][:, [0,2]] * in_w
        batch_target[:, [1,3]] = targets[b][:, [1,3]] * in_h
        batch_target[:, 4] = targets[b][:, 4]
        batch_target = batch_target.cpu()
        # 计算标签和anchor的IOU
        # 这里可以随便选一个共同中心(0,0),根据高宽计算IOU
        gt_box= torch.FloatTensor(torch.cat((torch.zeros((batch_target.size(0), 2)), batch_target[:, 2:4]), 1))
        anchor_shapes=torch.FloatTensor(torch.cat((torch.zeros((len(anchors), 2)), torch.FloatTensor(anchors)), 1))
        iou=calculate_iou(gt_box, anchor_shapes)
        # 获得与标签最匹配的anchor的索引
        best_ns = torch.argmax(iou, dim=-1)
        for t, best_n in enumerate(best_ns):
            # 第t个标签中心所在网格,种类
            i = torch.floor(batch_target[t, 0]).long()
            j = torch.floor(batch_target[t, 1]).long()
            c = batch_target[t, 4].long()
            positive[b,best_n,j,i,0]=batch_target[t, 0] - i.float()
            positive[b,best_n,j,i,1]=batch_target[t, 1] - j.float()
            positive[b,best_n,j,i,2]=math.log(batch_target[t, 2] / anchors[best_n][0])
            positive[b,best_n,j,i,3]=math.log(batch_target[t, 3] / anchors[best_n][1])
            positive[b,best_n,j,i,4]=1
            positive[b,best_n,j,i,c+5]=1
            negtive[b,best_n,j,i]=0
    return positive,negtive
  • 1
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论
Linux创始人LinusTorvalds有一句名言:Talk is cheap, Show me the code.(冗谈不够,放码过来!)。 代码阅读是从入门到提高的必由之路。尤其对深度学习,许多框架隐藏了神经网络底层的实现,只能在上层调包使用,对其内部原理很难认识清晰,不利于进一步优化和创新。  YOLOv3是一种基于深度学习的端到端实时目标检测方法,以速度快见长。YOLOv3的实现Darknet是使用C语言开发的轻型开源深度学习框架,依赖少,可移植性好,可以作为很好的代码阅读案例,让我们深入探究其实现原理。  本课程将解析YOLOv3的实现原理和源码,具体内容包括: YOLO目标检测原理  神经网络及Darknet的C语言实现,尤其是反向传播的梯度求解和误差计算 代码阅读工具及方法 深度学习计算的利器:BLAS和GEMM GPU的CUDA编程方法及在Darknet的应用 YOLOv3的程序流程及各层的源码解析本课程将提供注释后的Darknet的源码程序文件。  除本课程《YOLOv3目标检测:原理与源码解析》外,本人推出了有关YOLOv3目标检测的系列课程,包括:   《YOLOv3目标检测实战:训练自己的数据集》  《YOLOv3目标检测实战:交通标志识别》  《YOLOv3目标检测:原理与源码解析》  《YOLOv3目标检测:网络模型改进方法》 建议先学习课程《YOLOv3目标检测实战:训练自己的数据集》或课程《YOLOv3目标检测实战:交通标志识别》,对YOLOv3的使用方法了解以后再学习本课程。
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

刀么克瑟拉莫

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值