YOLO初学,主要对YOLO代码进行理解

17 篇文章 1 订阅
12 篇文章 0 订阅

相关知识点:

IOU:https://blog.csdn.net/iamoldpan/article/details/78799857

YOLO1:https://blog.csdn.net/m0_37192554/article/details/81092514

YOLO2:https://blog.csdn.net/lwplwf/article/details/82895409

               https://blog.csdn.net/u011507206/article/details/60884602

               https://blog.csdn.net/jesse_mx/article/details/53925356

YOLO3:https://blog.csdn.net/leviopku/article/details/82660381

对于他们的初步了解我是通过看他们论文以及上述博客进行的。

但是感觉论文写的也看的不是很懂。不是太能掌握整体流程。

 

代码:

这是这边讲解的,写的简单,但是似乎结果不对:https://github.com/18150167970/pytorch-yolov3-modifiy?tdsourcetag=s_pctim_aiomsg

这个比较难:https://github.com/DeNA/PyTorch_YOLOv3

 

前面加载数据,因为暂时对数据不太了解,先没看,主要看网络结构。

model = Darknet(opt.model_config_path)

--》

self.hyperparams, self.module_list = create_modules(self.module_defs)#创建模型

--》

当module_def["type"] != "yolo":按照自己的规则建立,这里其他的不难。

module_def["type"] == "yolo"

yolo_layer = YOLOLayer(anchors, num_classes, img_height)

前面几行代码是配置参数。

anchors总共是9个框,聚类得到的结果。在上述YOLO3博客的框图可以知道为啥每次三个不一样大小。

 

然后到class YOLOLayer(nn.Module)上,这是我们重点想讲解的。

self.bbox_attrs = 5 + num_classes#前5为坐标以及高度框度还有背景的。后面是类别。所以总的文章是85

nA = self.num_anchors  # 3
nB = x.size(0)  # batchsize
nG = x.size(2)  # 13 26 52对应

prediction = x.view(nB, nA, self.bbox_attrs, nG, nG).permute(
            0, 1, 3, 4, 2).contiguous()#把网络输出的结果拆除,x.shape:(1, 255, 13, 13)--》(1, 3, 13, 13, 85)

 

x = torch.sigmoid(prediction[..., 0])  # Center x (1,3,13,13)
y = torch.sigmoid(prediction[..., 1])  # Center y
w = prediction[..., 2]  # Width
h = prediction[..., 3]  # Height
pred_conf = torch.sigmoid(prediction[..., 4])  # bbox的置信度

前五个非类别,最后80为类别
pred_cls = torch.sigmoid(prediction[..., 5:])  # 每个类别的概率

 

grid_x = torch.arange(nG).repeat(nG, 1).view(
            [1, 1, nG, nG]).type(FloatTensor)  # 五个单元左上角坐标x, 橫向加一,纵向相同
grid_y = torch.arange(nG).repeat(nG, 1).t().view(
            [1, 1, nG, nG]).type(FloatTensor)  # 五个单元左上角坐标x, 橫向相同,纵向加一

pred_boxes[..., 0] = x.data + grid_x
pred_boxes[..., 1] = y.data + grid_y

每个像素的偏移。

 

进入重点部分:

找真值图

nGT, nCorrect, mask, conf_mask, tx, ty, tw, th, tconf, tcls = build_targets(
                pred_boxes=pred_boxes.cpu().data,
                pred_conf=pred_conf.cpu().data,
                pred_cls=pred_cls.cpu().data,
                target=targets.cpu().data,
                anchors=scaled_anchors.cpu().data,
                num_anchors=nA,
                num_classes=self.num_classes,
                grid_size=nG,
                ignore_thres=self.ignore_thres,
                img_dim=self.image_dim,
            )

nGT = 0#真值总数
nCorrect = 0#分对的总数

gt_box = torch.FloatTensor(np.array([0, 0, gw, gh])).unsqueeze(0)#torch.Size([1, 4])

            # Get shape of anchor box
anchor_shapes = torch.FloatTensor(np.concatenate(
                (np.zeros((len(anchors), 2)), np.array(anchors)), 1))#torch.Size([3, 4])

anch_ious = bbox_iou(gt_box, anchor_shapes)#聚类的高宽和真值高宽的IOU

 

best_n = np.argmax(anch_ious)#找出IOU最大的,

pred_box = pred_boxes[b, best_n, gj, gi].unsqueeze(0)
            # Masks,用于找到最高重叠率的预测窗口
mask[b, best_n, gj, gi] = 1
conf_mask[b, best_n, gj, gi] = 1

iou = bbox_iou(gt_box, pred_box, x1y1x2y2=False)
pred_label = torch.argmax(pred_cls[b, best_n, gj, gi])
score = pred_conf[b, best_n, gj, gi]
if iou > 0.5 and pred_label == target_label and score > 0.5:
       nCorrect += 1  # 用于计算召回率和准确率

然后就可以计算损失了。

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值