tools.py(Yolov1开源项目代码详细理解)

项目代码: 来自github的YOLOv1开源项目
本文是关于tools.py的详细理解。

该文件包含3个函数和一个类

  1. MSELoss
  2. generate_dxdywh
  3. gt_creator
  4. loss

MSELoss

class MSELoss(nn.Module):
    def __init__(self, reduction='mean'):
        super(MSELoss, self).__init__()#子类的构造函数需要先调用父类的构造函数进行初始化,以确保子类继承了父类的属性和方法。
        self.reduction = reduction
  1. 这个类用于计算均方误差(Mean Squared Error, MSE)损失。
    MSE计算方法: M S E ( y , y ^ ) = 1 n ∑ i = 1 n ( y i − y i ^ ) 2 MSE(y,\hat{y})=\frac{1}{n}\sum_{i=1}^{n}(y_i-\hat{y_i})^2 MSE(y,y^)=n1i=1n(yiyi^)2
    其中 y y y 是真实值, y ^ \hat{y} y^ 是预测值, n n n 是样本数。
    在初始化函数中,通过调用nn.Module的__init__方法来初始化父类。其中,reduction参数用于指定计算损失时采用的降维方式,可以取值为’mean’、‘sum’和’none’。默认为’mean’,表示对损失进行求和后再除以样本数求平均。
    def forward(self, inputs, targets):
        pos_id = (targets==1.0).float()
        neg_id = (targets==0.0).float()
        pos_loss = pos_id * (inputs - targets)**2#对正例目标进行了惩罚
        neg_loss = neg_id * (inputs)**2
        if self.reduction == 'mean':
            pos_loss = torch.mean(torch.sum(pos_loss, 1))
            neg_loss = torch.mean(torch.sum(neg_loss, 1))
            return pos_loss, neg_loss
        else:
            return pos_loss, neg_loss
  1. 这是一个损失函数前向传递函数,这个函数计算的是二分类损失函数中的正样本损失负样本损失。targets是指二分类损失函数中的真实标签,它的取值应该是0或1。torch.sum(pos_loss, 1) 是对 pos_loss 张量在第1个维度上进行求和操作,返回一个尺寸为 (batch_size,) 的张量,torch.mean(input) 将返回所有元素的平均值。
def generate_dxdywh(gt_label, w, h, s):
    xmin, ymin, xmax, ymax = gt_label[:-1]#gt_label中取出了目标框的位置信息
  1. 这个函数的输入中gt_label是一个包含目标框信息的列表,其包含五个元素,分别是xmin、ymin、xmax、ymax和目标框对应的类别编号。xmin、ymin、xmax、ymax分别表示目标框左上角和右下角的相对坐标(取值在0-1之间),目标框对应的类别编号通常是一个整数,用于表示不同类别的目标。w,h则是目标框的宽和高,是具体的像素数值。s是网格的尺寸。
    # compute the center, width and height
    c_x = (xmax + xmin) / 2 * w
    c_y = (ymax + ymin) / 2 * h
    box_w = (xmax - xmin) * w
    box_h = (ymax - ymin) * h
  1. 这行代码是计算目标框的中心点在图像上的横坐标,具体的计算过程为:
    计算目标框的宽度 b o x w box_w boxw b o x w = ( x m a x − x m i n ) × w box_w = (xmax - xmin) \times w boxw=(xmaxxmin)×w
    计算目标框的中心点横坐标 c x c_x cx c x = ( x m a x + x m i n ) / 2 × w c_x = (xmax + xmin) / 2 \times w cx=(xmax+xmin)/2×w
    这里的坐标xmin, ymin, xmax, ymax应该是0-1的相对值。
    if box_w < 1. or box_h < 1.:
        # print('A dirty data !!!')
        return False    
  1. 这段代码用于判断标注框是否合法,如果标注框的宽度或高度小于1,则视为不合法的数据,返回False。
    # map center point of box to the grid cell
    c_x_s = c_x / s
    c_y_s = c_y / s
    grid_x = int(c_x_s)
    grid_y = int(c_y_s)
    # compute the (x, y, w, h) for the corresponding grid cell
    tx = c_x_s - grid_x
    ty = c_y_s - grid_y
    tw = np.log(box_w)
    th = np.log(box_h)
    weight = 2.0 - (box_w / w) * (box_h / h)

    return grid_x, grid_y, tx, ty, tw, th, weight

  1. c_x 是中心点的绝对坐标,s是一个网格的尺寸,c_x_s就是中心点相对于整个网格图相对坐标,grid_x 是该检测框中心点所在的网格的x坐标。tx,ty则是相对于cell的相对坐标,值在0-1之间。tw,th取对数的目的是为了缩小宽度的值域,便于后续的计算和处理。weight对于不同大小的目标框给予不同的权重,以便更好地平衡损失值的大小。公式中的2.0是一个超参数,可以根据具体任务进行调整,一般来说,当目标框越大时,权重越小,因为对于大目标框来说,在偏差值接近的情况下,比起小目标框,它的错误程度更轻,因此给予它更小的权重可以平衡误差。
def gt_creator(input_size, stride, label_lists=[], name='VOC'):
  1. 这个函数主要的作用是把ground truth(真实标签)的信息转换成一个张量(tensor)数据类型。其输入包括input_size表示输入图像的大小,stride表示网格尺寸,label_lists表示目标标注信息,例如目标所在的位置、大小、类别等。最后函数返回一个gt_tensor ,是一个表示 ground-truth 的张量,它的形状是 [batch_size, hs, ws, 1+1+4+1],其中 batch_size 是 batch 的大小,hs 和 ws 是输入图像的高度和宽度按照步长 stride 划分之后的格子数,第三个维度的 7 个元素分别是表示物体是否存在的标志、物体类别、物体中心坐标和宽高的偏移量、权重。
    assert len(input_size) > 0 and len(label_lists) > 0#确保传入的参数input_size和label_lists都是非空的,如果有一个为空则会抛出AssertionError异常
    # prepare the all empty gt datas
    batch_size = len(label_lists)#输入的标签列表的数量
    w = input_size[1]#输入图像的宽度和高度
    h = input_size[0]
        # We  make gt labels by anchor-free method and anchor-based method.
    ws = w // stride
    hs = h // stride
    s = stride
    gt_tensor = np.zeros([batch_size, hs, ws, 1+1+4+1])
  1. 函数会将这些信息用于后续的处理过程。在这里,gt_tensor被初始化为一个形状为 [batch_size, hs, ws, 1+1+4+1] 的全零数组。
    # generate gt whose style is yolo-v1
    for batch_index in range(batch_size):
        for gt_label in label_lists[batch_index]:
            gt_class = int(gt_label[-1])#对于每个标签,它的类别信息存储在其字符串格式的最后一个字符中
            result = generate_dxdywh(gt_label, w, h, s)#generate_dxdywh()函数计算其边框相对于特定网格点的偏移量和尺寸,生成相应的ground truth信息
            if result:   
                grid_x, grid_y, tx, ty, tw, th, weight = result

                if grid_x < gt_tensor.shape[2] and grid_y < gt_tensor.shape[1]:
                    gt_tensor[batch_index, grid_y, grid_x, 0] = 1.0
                    gt_tensor[batch_index, grid_y, grid_x, 1] = gt_class
                    gt_tensor[batch_index, grid_y, grid_x, 2:6] = np.array([tx, ty, tw, th])
                    gt_tensor[batch_index, grid_y, grid_x, 6] = weight
    gt_tensor = gt_tensor.reshape(batch_size, -1, 1+1+4+1)#将gt_tensor从[batch_size, hs, ws, 1+1+4+1]重新排列为[batch_size, hs*ws, 1+1+4+1]

  1. 这段代码使用循环遍历每个batch的所有标签,并且把数据保存到gt_tensor中。
def loss(pred_conf, pred_cls, pred_txtytwth, label):
  1. 这个函数主要是计算目标检测中的损失函数,它输入预测值和标签值,并返回每个损失的值和总体损失的值。包括objectness loss(目标存在性损失)、class loss(类别损失)和box loss(坐标框损失)。
    obj = 5.0
    noobj = 1.0

  1. 在训练过程中,损失函数中的正样本权重设为obj,负样本权重设为noobj。通常情况下,obj比noobj大,因为定位误差比置信度误差更难以消除。定位误差是指模型预测的目标位置和实际目标位置之间的差异。置信度误差是指模型预测的置信度(目标存在的概率)和实际置信度之间的差异。因为置信度误差直接影响模型是否认为存在目标,所以准确度的提升对应的是模型能否成功识别出目标,因此会受到更多的关注。简单来说,置信度误差的影响体现在是否能够检测到目标,而定位误差的影响则体现在模型的预测精度上,因此定位误差更难以纠正。
    # create loss_f
    conf_loss_function = MSELoss(reduction='mean')#实例化MSELoss类
    cls_loss_function = nn.CrossEntropyLoss(reduction='none')
    txty_loss_function = nn.BCEWithLogitsLoss(reduction='none')
    twth_loss_function = nn.MSELoss(reduction='none')

    pred_conf = torch.sigmoid(pred_conf[:, :, 0])
    pred_cls = pred_cls.permute(0, 2, 1)
    pred_txty = pred_txtytwth[:, :, :2]#预测框中心点坐标相对于网格左上角的偏移量
    pred_twth = pred_txtytwth[:, :, 2:]#预测的宽和高
        
    gt_obj = label[:, :, 0].float()#标签中的目标的存在标志
    gt_cls = label[:, :, 1].long()#目标的类别标签
    gt_txtytwth = label[:, :, 2:-1].float()#目标的中心点坐标偏移量和宽高的真实值
    gt_box_scale_weight = label[:, :, -1]
  1. 首先将pred_conf在维度2上取第一个元素并进行sigmoid激活,这里使用sigmoid函数是因为pred_conf的预测值在训练过程中可能会大于1或者小于0(0代表没有物体,1代表有物体),但是objectness score的取值范围必须是[0, 1]。这样做的目的是将预测的objectness score和实际的ground truth objectness score之间的误差最小化。
  2. 然后将pred_cls的第2和第3个维度进行转置,变为(batch_size, num_anchors, num_classes),方便和标签gt_cls计算交叉熵损失函数。
    # objectness loss
    pos_loss, neg_loss = conf_loss_function(pred_conf, gt_obj)
    conf_loss = obj * pos_loss + noobj * neg_loss
  1. 这段代码计算了损失函数中的置信度损失,首先使用 conf_loss_function 计算出预测的置信度和实际置信度之间的误差,使用 objnoobj 对两个部分分别进行加权,得到最终的置信度损失 conf_loss
    # class loss
    cls_loss = torch.mean(torch.sum(cls_loss_function(pred_cls, gt_cls) * gt_obj, 1))
  1. 这段代码计算了分类损失,使用的损失函数是交叉熵损失函数(CrossEntropyLoss)。其中,pred_cls是模型预测的每个先验框(anchor)属于不同类别的概率分布,gt_cls是每个先验框真实的类别标签。在计算损失时,对于每个先验框,只有当其对应的网格单元(grid cell)中存在目标时(即gt_obj为1时),分类损失才会被计算,否则分类损失为0。最终,分类损失是所有存在目标的先验框的分类损失之和的平均值。
    # box loss
    txty_loss = torch.mean(torch.sum(torch.sum(txty_loss_function(pred_txty, gt_txtytwth[:, :, :2]), 2) * gt_box_scale_weight * gt_obj, 1))
    twth_loss = torch.mean(torch.sum(torch.sum(twth_loss_function(pred_twth, gt_txtytwth[:, :, 2:]), 2) * gt_box_scale_weight * gt_obj, 1))

    txtytwth_loss = txty_loss + twth_loss
  1. txty_loss_function是一个二元交叉熵损失函数,用于衡量预测的物体中心点坐标和标注中心点坐标之间的差异,sum() 函数的第二个参数是指在哪个维度上进行求和操作。
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值