项目代码: 来自github的YOLOv1开源项目
本文是关于tools.py的详细理解。
该文件包含3个函数和一个类
- MSELoss
- generate_dxdywh
- gt_creator
- loss
MSELoss
class MSELoss(nn.Module):
def __init__(self, reduction='mean'):
super(MSELoss, self).__init__()#子类的构造函数需要先调用父类的构造函数进行初始化,以确保子类继承了父类的属性和方法。
self.reduction = reduction
- 这个类用于计算均方误差(Mean Squared Error,
MSE
)损失。
MSE计算方法: M S E ( y , y ^ ) = 1 n ∑ i = 1 n ( y i − y i ^ ) 2 MSE(y,\hat{y})=\frac{1}{n}\sum_{i=1}^{n}(y_i-\hat{y_i})^2 MSE(y,y^)=n1∑i=1n(yi−yi^)2
其中 y y y 是真实值, y ^ \hat{y} y^ 是预测值, n n n 是样本数。
在初始化函数中,通过调用nn.Module的__init__
方法来初始化父类。其中,reduction
参数用于指定计算损失时采用的降维方式,可以取值为’mean’、‘sum’和’none’。默认为’mean’,表示对损失进行求和后再除以样本数求平均。
def forward(self, inputs, targets):
pos_id = (targets==1.0).float()
neg_id = (targets==0.0).float()
pos_loss = pos_id * (inputs - targets)**2#对正例目标进行了惩罚
neg_loss = neg_id * (inputs)**2
if self.reduction == 'mean':
pos_loss = torch.mean(torch.sum(pos_loss, 1))
neg_loss = torch.mean(torch.sum(neg_loss, 1))
return pos_loss, neg_loss
else:
return pos_loss, neg_loss
- 这是一个损失函数前向传递函数,这个函数计算的是二分类损失函数中的正样本损失和负样本损失。targets是指二分类损失函数中的真实标签,它的取值应该是0或1。
torch.sum(pos_loss, 1)
是对 pos_loss 张量在第1个维度上进行求和操作,返回一个尺寸为 (batch_size,) 的张量,torch.mean(input)
将返回所有元素的平均值。
def generate_dxdywh(gt_label, w, h, s):
xmin, ymin, xmax, ymax = gt_label[:-1]#gt_label中取出了目标框的位置信息
- 这个函数的输入中
gt_label
是一个包含目标框信息的列表,其包含五个元素,分别是xmin、ymin、xmax、ymax和目标框对应的类别编号。xmin、ymin、xmax、ymax分别表示目标框左上角和右下角的相对坐标(取值在0-1之间),目标框对应的类别编号通常是一个整数,用于表示不同类别的目标。w,h
则是目标框的宽和高,是具体的像素数值。s
是网格的尺寸。
# compute the center, width and height
c_x = (xmax + xmin) / 2 * w
c_y = (ymax + ymin) / 2 * h
box_w = (xmax - xmin) * w
box_h = (ymax - ymin) * h
- 这行代码是计算目标框的中心点在图像上的横坐标,具体的计算过程为:
计算目标框的宽度 b o x w box_w boxw : b o x w = ( x m a x − x m i n ) × w box_w = (xmax - xmin) \times w boxw=(xmax−xmin)×w
计算目标框的中心点横坐标 c x c_x cx : c x = ( x m a x + x m i n ) / 2 × w c_x = (xmax + xmin) / 2 \times w cx=(xmax+xmin)/2×w
这里的坐标xmin, ymin, xmax, ymax
应该是0-1的相对值。
if box_w < 1. or box_h < 1.:
# print('A dirty data !!!')
return False
- 这段代码用于判断标注框是否合法,如果标注框的宽度或高度小于1,则视为不合法的数据,返回False。
# map center point of box to the grid cell
c_x_s = c_x / s
c_y_s = c_y / s
grid_x = int(c_x_s)
grid_y = int(c_y_s)
# compute the (x, y, w, h) for the corresponding grid cell
tx = c_x_s - grid_x
ty = c_y_s - grid_y
tw = np.log(box_w)
th = np.log(box_h)
weight = 2.0 - (box_w / w) * (box_h / h)
return grid_x, grid_y, tx, ty, tw, th, weight
c_x
是中心点的绝对坐标,s
是一个网格的尺寸,c_x_s
就是中心点相对于整个网格图相对坐标,grid_x
是该检测框中心点所在的网格的x坐标。tx,ty
则是相对于cell的相对坐标,值在0-1之间。tw,th
取对数的目的是为了缩小宽度的值域,便于后续的计算和处理。weight
对于不同大小的目标框给予不同的权重,以便更好地平衡损失值的大小。公式中的2.0是一个超参数,可以根据具体任务进行调整,一般来说,当目标框越大时,权重越小,因为对于大目标框来说,在偏差值接近的情况下,比起小目标框,它的错误程度更轻,因此给予它更小的权重可以平衡误差。
def gt_creator(input_size, stride, label_lists=[], name='VOC'):
- 这个函数主要的作用是把ground truth(真实标签)的信息转换成一个张量(tensor)数据类型。其输入包括
input_size
表示输入图像的大小,stride
表示网格尺寸,label_lists
表示目标标注信息,例如目标所在的位置、大小、类别等。最后函数返回一个gt_tensor
,是一个表示 ground-truth 的张量,它的形状是[batch_size, hs, ws, 1+1+4+1]
,其中 batch_size 是 batch 的大小,hs 和 ws 是输入图像的高度和宽度按照步长 stride 划分之后的格子数,第三个维度的 7 个元素分别是表示物体是否存在的标志、物体类别、物体中心坐标和宽高的偏移量、权重。
assert len(input_size) > 0 and len(label_lists) > 0#确保传入的参数input_size和label_lists都是非空的,如果有一个为空则会抛出AssertionError异常
# prepare the all empty gt datas
batch_size = len(label_lists)#输入的标签列表的数量
w = input_size[1]#输入图像的宽度和高度
h = input_size[0]
# We make gt labels by anchor-free method and anchor-based method.
ws = w // stride
hs = h // stride
s = stride
gt_tensor = np.zeros([batch_size, hs, ws, 1+1+4+1])
- 函数会将这些信息用于后续的处理过程。在这里,
gt_tensor
被初始化为一个形状为 [batch_size, hs, ws, 1+1+4+1] 的全零数组。
# generate gt whose style is yolo-v1
for batch_index in range(batch_size):
for gt_label in label_lists[batch_index]:
gt_class = int(gt_label[-1])#对于每个标签,它的类别信息存储在其字符串格式的最后一个字符中
result = generate_dxdywh(gt_label, w, h, s)#generate_dxdywh()函数计算其边框相对于特定网格点的偏移量和尺寸,生成相应的ground truth信息
if result:
grid_x, grid_y, tx, ty, tw, th, weight = result
if grid_x < gt_tensor.shape[2] and grid_y < gt_tensor.shape[1]:
gt_tensor[batch_index, grid_y, grid_x, 0] = 1.0
gt_tensor[batch_index, grid_y, grid_x, 1] = gt_class
gt_tensor[batch_index, grid_y, grid_x, 2:6] = np.array([tx, ty, tw, th])
gt_tensor[batch_index, grid_y, grid_x, 6] = weight
gt_tensor = gt_tensor.reshape(batch_size, -1, 1+1+4+1)#将gt_tensor从[batch_size, hs, ws, 1+1+4+1]重新排列为[batch_size, hs*ws, 1+1+4+1]
- 这段代码使用循环遍历每个batch的所有标签,并且把数据保存到
gt_tensor
中。
def loss(pred_conf, pred_cls, pred_txtytwth, label):
- 这个函数主要是计算目标检测中的损失函数,它输入预测值和标签值,并返回每个损失的值和总体损失的值。包括objectness loss(目标存在性损失)、class loss(类别损失)和box loss(坐标框损失)。
obj = 5.0
noobj = 1.0
- 在训练过程中,损失函数中的正样本权重设为
obj
,负样本权重设为noobj
。通常情况下,obj比noobj大,因为定位误差比置信度误差更难以消除。定位误差是指模型预测的目标位置和实际目标位置之间的差异。置信度误差是指模型预测的置信度(目标存在的概率)和实际置信度之间的差异。因为置信度误差直接影响模型是否认为存在目标,所以准确度的提升对应的是模型能否成功识别出目标,因此会受到更多的关注。简单来说,置信度误差的影响体现在是否能够检测到目标,而定位误差的影响则体现在模型的预测精度上,因此定位误差更难以纠正。
# create loss_f
conf_loss_function = MSELoss(reduction='mean')#实例化MSELoss类
cls_loss_function = nn.CrossEntropyLoss(reduction='none')
txty_loss_function = nn.BCEWithLogitsLoss(reduction='none')
twth_loss_function = nn.MSELoss(reduction='none')
pred_conf = torch.sigmoid(pred_conf[:, :, 0])
pred_cls = pred_cls.permute(0, 2, 1)
pred_txty = pred_txtytwth[:, :, :2]#预测框中心点坐标相对于网格左上角的偏移量
pred_twth = pred_txtytwth[:, :, 2:]#预测的宽和高
gt_obj = label[:, :, 0].float()#标签中的目标的存在标志
gt_cls = label[:, :, 1].long()#目标的类别标签
gt_txtytwth = label[:, :, 2:-1].float()#目标的中心点坐标偏移量和宽高的真实值
gt_box_scale_weight = label[:, :, -1]
- 首先将
pred_conf
在维度2上取第一个元素并进行sigmoid
激活,这里使用sigmoid函数是因为pred_conf的预测值在训练过程中可能会大于1或者小于0(0代表没有物体,1代表有物体),但是objectness score的取值范围必须是[0, 1]。这样做的目的是将预测的objectness score和实际的ground truth objectness score之间的误差最小化。 - 然后将
pred_cls
的第2和第3个维度进行转置,变为(batch_size, num_anchors, num_classes),方便和标签gt_cls
计算交叉熵损失函数。
# objectness loss
pos_loss, neg_loss = conf_loss_function(pred_conf, gt_obj)
conf_loss = obj * pos_loss + noobj * neg_loss
- 这段代码计算了损失函数中的置信度损失,首先使用
conf_loss_function
计算出预测的置信度和实际置信度之间的误差,使用obj
和noobj
对两个部分分别进行加权,得到最终的置信度损失conf_loss
。
# class loss
cls_loss = torch.mean(torch.sum(cls_loss_function(pred_cls, gt_cls) * gt_obj, 1))
- 这段代码计算了分类损失,使用的损失函数是交叉熵损失函数(CrossEntropyLoss)。其中,pred_cls是模型预测的每个先验框(anchor)属于不同类别的概率分布,gt_cls是每个先验框真实的类别标签。在计算损失时,对于每个先验框,只有当其对应的网格单元(grid cell)中存在目标时(即gt_obj为1时),分类损失才会被计算,否则分类损失为0。最终,分类损失是所有存在目标的先验框的分类损失之和的平均值。
# box loss
txty_loss = torch.mean(torch.sum(torch.sum(txty_loss_function(pred_txty, gt_txtytwth[:, :, :2]), 2) * gt_box_scale_weight * gt_obj, 1))
twth_loss = torch.mean(torch.sum(torch.sum(twth_loss_function(pred_twth, gt_txtytwth[:, :, 2:]), 2) * gt_box_scale_weight * gt_obj, 1))
txtytwth_loss = txty_loss + twth_loss
txty_loss_function
是一个二元交叉熵损失函数,用于衡量预测的物体中心点坐标和标注中心点坐标之间的差异,sum()
函数的第二个参数是指在哪个维度上进行求和操作。