正常求loss需要两个参数就好了,一个是真实标注的信息,一个是预测的结果。
求是否包含物体的loss, 求真实框与预测框的loss,求类别信息的loss。
因为特征层一共有三层,所以先分别求每一层特征层的loss再讲他们的总和加起来。
当输入一个特征时,可以根据特征来获取框的x, y, w, h,置信度,每一个类别的概率等信息
然后根据获取到的这些信息来对先验框进行调整得到预测框
有了预测框之后,就可以求真实框与预测框的loss
有了每一个类别的概率之后就可以求类别的loss
有了置信度就可以求框是否包含物体的loss
不完整代码:
import torch
import torch.nn as nn
import numpy as np
class YoloLoss(nn.Module):
def __init__(self, anchors, num_classes, input_shape):
"""
anchors:先验框
num_clsses:物体种类个数
"""
self.anchors = anchors
self.num_classes = num_classes
self.input_shape = input_shape
super(YoloLoss, self).__init__()
def forward(self, l, feat, boxes=None, y_true=None):
"""
l:表示第 l层, l=0/1/2
feat: 表示第 l层feature map
boxes:表示真实框和类别信息
y_true:是将boxes的有用信息复制到全零矩阵上(也就是背景为0), 在计算loss的时候用到的
"""
# 1.feat是一个【bs,3*(nc+5),h,w)的特征,根据feat来获取box和cls信息
bs, feat_h, feat_w = feat.size(0), feat.size(2), feat.size(3)
# 2.计算输入的宽高与feat的宽高的比例,求出步长strides
stride_h, stride_w = self.input_shape[0]/feat_h, self.input_shape[1]/feat_w
scaled_anchors = [(w/stride_w, h/stride_h) for (w, h) in self.anchors]
# 3.获取feat里面的信息,先把feat变换维度
prediction = feat.view(bs, 3, self.num_classes+5, feat_h, feat_w).permute(0, 1, 3, 4, 2).contiguous()
# 4.获取x,y,w,h,c......等信息
x = torch.sigmoid(prediction[..., 0])
y = torch.sigmoid(prediction[..., 1])
w = torch.sigmoid(prediction[..., 2])
h = torch.sigmoid(prediction[..., 3])
conf = torch.sigmoid(prediction[..., 4])
pre_cls = torch.sigmoid(prediction[..., 5:])
# 5.对先验框进行调整,获取到预测框
pred_boxes = self.get_pred_boxes(l, x, y, w, h, boxes, scaled_anchors, feat_h, feat_w)
pass
def get_pred_boxes(self, l, x, y, w, h, boxes, scaled_anchors, feat_h, feat_w):
"""这个函数就是对先验框进行调整"""
bs = len(boxes)
# anchors只是先验框的宽高,要获取先验框的话需要先生成网格,根据网格再生成先验框
# repeat函数:(2,3)--repeat(1,2)-->(2,6); (2,3)--repeat(1,2,1)-->(1,2,6)
grid_x = torch.linspace(0, feat_w-1, feat_w).repeat(feat_h, 1).repaet(bs*3, 1, 1).view(x.shape).type_as(x)
grid_y = torch.linspace(0, feat_h-1, feat_h).repeat(feat_w, 1).repeat(bs*3, 1, 1).view(y.shape).type_as(y)
scaled_anchors_l = np.array(scaled_anchors)[self.anchors_mask[l]]
anchor_w = torch.Tensor(scaled_anchors_l).index_select(1, torch.LongTensor([0])).type_as(x)
anchor_h = torch.Tensor(scaled_anchors_l).index_select(1, torch.LongTensor([1])).type_as(x)
anchor_w = anchor_w.repeat(bs, 1).repeat(1, 1, feat_h * feat_w).view(w.shape)
anchor_h = anchor_h.repeat(bs, 1).repeat(1, 1, feat_h * feat_w).view(h.shape)
pred_boxes_x = torch.unsqueeze(x * 2. - 0.5 + grid_x, -1)
pred_boxes_y = torch.unsqueeze(y * 2. - 0.5 + grid_y, -1)
pred_boxes_w = torch.unsqueeze((w * 2) ** 2 * anchor_w, -1)
pred_boxes_h = torch.unsqueeze((h * 2) ** 2 * anchor_h, -1)
pred_boxes = torch.cat([pred_boxes_x, pred_boxes_y, pred_boxes_w, pred_boxes_h], dim=-1)
return pred_boxes