有时间会长期更新,主要是练一练自己的代码
前言
为了让自己安心学习,也为了可能会给更多的初学者帮助,因此我会在加下来这段时间开始更新YOLO算法的复现。
loss设计
# gt: batch_size*7*7*(5+20) 真实框
# pr: batch_size*3*7*7*(5+20) #预测框
class Loss_function():
def __init__(self):
self.size = 7 #特征图尺寸
self.num = 3 #多尺度预测
def calculate(self,gt,pr):
no_obj = 0 #无目标损失
loc_obj = 0 #有目标定位损失
cls_obj = 0 #有目标类别损失
for q in range(len(gt)): #遍历batch_size
for i in range(self.num): # self.num = 3
for j in range(self.size): # self.size = 7
for k in range(self.size): # self.size = 7
if gt[q, j, k, 4] == 1:
no_obj += 1 - pr[q,i, j, k, 4]
lx = (gt[q, j, k, 0] - pr[q, i, j, k, 0]) / pr[q, i, j, k, 2]
ly = (gt[q, j, k, 1] - pr[q, i, j, k, 1]) / pr[q, i, j, k, 3]
lw = np.log2(gt[q, j, k, 2] / pr[q, i, j, k, 2])
lh = np.log2(gt[q, j, k, 3] / pr[q, i, j, k, 3])
loc_obj += lx + ly + lw + lh
cls_obj += sum(np.abs(gt[q, j, k, 5:] - pr[q, i, j, k, 5:]))
else:
no_obj += pr[q, i, j, k, 4]
return no_obj+loc_obj+cls_obj
#测试
loss = Loss_function()
gt = np.random.random(size=(2,7,7,25))
pr = np.random.random(size=(2,3,7,7,25))
print(loss.calculate(gt,pr))
总结
非常简单的损失函数设计