代码主要分四个部分
一
- 数据读取与处理
- 网络代码
- loss计算
- 预测解码
以下主要对上面4个部分的关键代码进行详解,同时记录自己当前的一些疑问
1. 数据读取与处理
后面再补
2. 网络代码
后面再补
3. loss计算
本文代码主要简介pytorch版本的yolov3,和源码及ultralytics版本都有所出入,主要去思考其中的细节。
本节主要从三个函数进行讲解- loss主函数 forward:loss计算过程
先贴一张loss公式图
yolov3的loss分为三个组要部分:定位损失,置信度损失,类别损失。
由图中可知边框损失使用mse loss, (2-w * h) 是为了平衡大框和小框带来的影响,yolov1使用宽高的开方进行平衡,类别损失使用二分类交叉熵损失,解除了类别之间的互斥性。置信度仍然使用mse loss计算
# l代表特征尺度的索引
- loss主函数 forward:loss计算过程
bs = input.size(0)
in_h = input.size(2)
in_w = input.size(3)
#计算当前特征层的步长
stride_h = self.input_shape[0] / in_h
stride_w = self.input_shape[1] / in_w
#将anchor缩放到当前特征图尺寸
scales_anchors = [(a_w / stride_w, a_h / stride_h) for a_w, a_h in self.anchors]
prediction = input.view(bs,len(self.anchors_mask[l]), self.bbox_attrs, in_h, in_w).premute(0,1,3,4,2).contiguous()
#获取预测结果,其中x,y,conf,class_conf需要归一化
x = torch.sigmoid(prediction[...,0])
y = torch.sigmoid(prediction[...,1])
w = prediction[...,2]
h = prediction[...,3]
conf = torch.sigmoid(prediction[...,4])
pred_cls = torch.sigmoid(prediction[...,5:])
#样本匹配获取mask
y_true, noobj_mask, box_loss_scale = self.get_target(l, targets, scales_anchors, in_h, in_w