阅读utils/loss.py,掌握YOLO算法的loss计算过程(build_target)

转载自:读loss.py - Yuxi001 - 博客园 (cnblogs.com)

这个损失函数跟YOLO-V5的损失函数相同,最关键的函数是build_target原理很抽象,代码更抽象。想要读懂耗时大概2~3天。由于学长时间有限,所以就只把最关键的部分讲一下。

原理解析

对于某一张输入图片,YOLO算法把图片划分成多个网格(e.g. 13x13),然后根据图片中目标所在的位置,使用对应位置的网格来进行预测,这一步骤实际上可以理解成:把目标分配给某一个网格。在使用anchor的情况下,每一个网格包含多个anchor,此时还要把目标分配给网格中的某一个anchor。所以整体上看就是把目标分配给anchor,称为“anchor assign”。这就是build_target函数所做的事情。它大致分成以下几个步骤:

  1. 取出某一个批次图片中所有的目标。

  2. 根据每一个目标的位置x,y,找到负责预测它的anchor。

  3. 通过高维数组的特殊索引操作:"fancy indexing",取出被分配到目标的anchor对应的预测值

  4. 计算每个anchor的损失值,然后就可以反向传播了

代码解析 

  1.  取出某一个批次图片中所有的目标

utils/datasets.py中,有一个函数:

def collate_fn(batch): 
    img, label = zip(*batch) 
    for i, l in enumerate(label): 
        if l.shape[0] > 0: 
            l[:, 0] = i 
    return torch.stack(img), torch.cat(label, 0)

各个变量的含义如下:

  • batch:DataLoader调用batch_size次TensorDataset类的__getitem__函数,获得的返回值放在一起就是batch。

  • img,label:zip(*xxx)利用 * 号操作符,可以将元组解压为列表。__getitem__函数会返回一张图片跟它的标签,图片形状是(通道数,高,宽),简写成(c, h, w),标签是个矩阵,形状是(目标个数,6)。每一个目标本来只有5个属性x, y, w, h, class_id,但是多预留一个属性,用来保存这个目标所在的图片在这一个批次中的索引。

  • 返回值:有俩。第一个是把图片拼起来,形状是(batch_size, c, h, w)。第二个是批次中所有的目标,是一个矩阵,形状是(所有目标数量,6)。

这个函数的返回值就包含了这个批次图片中所有的目标。

注:“(所有目标数量,6)”中的6,指的就是每个目标的属性:"该图片在这个batch中的索引img_index和x,y,w,h,class_id"

2. 把每一个目标分配给合适的anchor

utils/loss.pybuild_target函数完成这件事儿。细分成以下几个步骤:

        1.函数输入值

preds:网络的输出值,长度为6的元组。

targets:函数collate_fn的第二个返回值,形状是(所有目标数量,6)。

        2.加载配置文件中设定的anchor

#加载anchor配置 
anchors = np.array(cfg["anchors"]) 
anchors = torch.from_numpy(anchors.reshape(len(preds) // 3, anchor_num, 2)).to(device)

变量anchors的形状是(尺度数量,anchor数量,2),2表示每一个anchor的宽高。如果你不知道尺度是什么意思,说明看的不认真,往回看看。

        3.谜语部分

gain = torch.ones(7, device = device)

at = torch.arange(anchor_num, device = device).float().view(anchor_num, 1).repeat(1, label_num)
targets = torch.cat((targets.repeat(anchor_num, 1, 1), at[:, :, None]), 2)

g = 0.5  # bias
off = torch.tensor([[0, 0],
                    [1, 0], [0, 1], [-1, 0], [0, -1],  # j,k,l,m
                    # [1, 1], [1, -1], [-1, 1], [-1, -1],  # jk,jm,lk,lm
                   ], device = device).float() * g  # offsets

这一部分代码充分体现了乱起变量名带来的后果。各个变量的含义:

  • gain:后面就知道了,用来存在特征图的大小。形状是(7, ),并且初始化为1。

  • at:后面就知道了。形状是(anchor_num,label_num),label_num就是整个批次图片包含的目标数量。可以查一下repeat函数的作用。
  • targets.repeat(anchor_num, 1, 1):(anchor_num, label_num, 6)
  • at[:, :, None]:(anchor_num,label_num, 1)

    由于每一个目标(label)都会分配给某一个网格,每一个网格中的anchor数量为anchor_num,所以targets先把目标复制了anchor_num遍,等价于把目标同时分配给了网格内的所有anchor。然后后面的代码中会以"向量化"的方式判断这几个anchor是否适合这个目标。
  • goff:后面就知道了。

        4.关键部分

  • 进入for循环。首先是:
#将label坐标映射到特征图上 

gain[2:6] = torch.tensor(pred.shape)[[3, 2, 3, 2]] 

gt = targets * gain
  • 各个表达式的含义:
  • torch.tensor(pred.shape)[[3, 2, 3, 2]]:用了一个非常简单的fancy indexing,pred.shape返回值是(n, c, h, w),w, h是特征图的宽高,由于这个特征图是网络预测值,所以如果网格划分是13x13的话,那么h, w都是13。索引2对应的是h,3对应的是w,所以这个表达式返回值是(w, h, w, h)。

  • gain[2:6] = torch.tensor(pred.shape)[[3, 2, 3, 2]]:把(w, h, w, h)赋值给gain,此时gain的内容是:(1, 1, w, h, w, h, 1)

  • gt = targets * gain:targets的形状是(anchor_num,label_num, 7),7个分量的含义是:(img_index, x, y, w, h, class_id, anchor_id),乘上gain后会发生"数组广播操作",相当于这7个分量跟gain对应位置相乘,把坐标从原来的百分比坐标转成了特征图坐标。gt的形状是(anchor_num,label_num, 7),仍然表示批次内所有的目标。举个例子,算了不举了。

  • 然后是:
 #anchor iou匹配
    r = gt[:, :, 4:6] / anchors_cfg[:, None]
    j = torch.max(r, 1. / r).max(2)[0] < 2

    t = gt[j]

各行代码的含义:

  • 第一行:之前咱们把某一个目标分配给了所属的网格中所有的anchor,这一行就是计算目标跟分配的anchor的长宽比。

  • 第二行:根据长宽比判断是否合适,返回的j是一个长度跟gt一样的一维bool数组,用来表示每一个目标跟分配给它的anchor是否合适。合适的标准就是目标大小跟anchor差不多。

  • 第三行:过滤出那些跟分配的anchor很合适,已经找到了自己的归宿的目标。

  • 最后是:
#扩充维度并复制数据
    # Offsets
    gxy = t[:, 2:4]  # grid xy
    gxi = gain[[2, 3]] - gxy  # inverse
    j, k = ((gxy % 1. < g) & (gxy > 1.)).T
    l, m = ((gxi % 1. < g) & (gxi > 1.)).T
    j = torch.stack((torch.ones_like(j), j, k, l, m))
    t = t.repeat((5, 1, 1))[j]
    offsets = (torch.zeros_like(gxy)[None] + off[:, None])[j]
  • 首先需要知道YoloV5所提出的特殊的anchor分配策略,这个网上资料很多,所以大家自行查找就行。简单来说就是:
  1. 如果目标的中心点(x,y)落在所属网格的左上方,那么除了所属的网格外,额外让左边跟上面的两个网格也预测这个目标。

  2. 如果目标的中心点(x,y)落在所属网格的右上方,那么除了所属的网格外,额外让右边跟上面的两个网格也预测这个目标。

  3. 以此类推。

  • 这一小部分代码大家自己看,难点在于怎么用高维数组操作来表示上面的扩充过程,网上(知乎,bilibili)也有一些解析,搜索yolov5 build_target就能找到。
  • 后面的函数返回值部分应该很简单就看懂了。

 3.通过高维数组的特殊索引操作:"fancy indexing",取出被分配到目标的anchor对应的预测值

  • 对应compute_loss函数中的这段  代码:
#构建gt
    tcls, tbox, indices, anchors = build_target(preds, targets, cfg, device)

    for i, pred in enumerate(preds):
        #计算reg分支loss
        if i % 3 == 0:
            pred = pred.reshape(pred.shape[0], cfg["anchor_num"], -1, pred.shape[2], pred.shape[3])
            pred = pred.permute(0, 1, 3, 4, 2)
            
            #判断当前batch数据是否有gt
            if len(indices):
                b, a, gj, gi = indices[layer_index[i]]
                nb = b.shape[0]

                if nb:
                    ps = pred[b, a, gj, gi]
  • 首先接收build_target函数的返回值,也就是anchor assign的结果。然后把返回值作为索引取出模型预测值的对应部分。这一行就解释了build_target函数的最终目的。
  • 后面就是计算损失值了。
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值