yolov1模型结构和训练测试流程详解

一、网络结构
在这里插入图片描述
①首先经过一个VGG主干网络提取特征,这里的主干网络可以自己选择,使用resnet也可以。
②reshape为一维,然后进行全连接,in_dim=25088,out_dim=4096,需要注意的是这里的25088是由51277得到的,而不同大小的图像经过主干网络提取特征后的大小也是不一样的,所以如果输入的图像大小不是448的话,就需要修改这里的in_dim。
③第二个全连接层,in_dim=4096,out_dim=1470,这里的out_dim是由7730得到的,是为了reshape成7×7×30的向量,至于为什么yolov1的输出是一个7×7×30的向量,而不是r-cnn系列的具体的框坐标等信息,后面会解释。

二、数据处理
①在训练时,会对数据进行随机的翻转和裁剪等操作

img, boxes = self.random_flip(img, boxes)#随机翻转
img,boxes = self.randomScale(img,boxes)#伸缩变形
img = self.randomBlur(img)#平滑处理
img = self.RandomBrightness(img)#亮度调节
img = self.RandomHue(img)#色度调节
img = self.RandomSaturation(img)#饱和度调节
img,boxes,labels = self.randomShift(img,boxes,labels)#随机平移
img,boxes,labels = self.randomCrop(img,boxes,labels)#随机裁剪

②训练和测试,会对输入网络前的数据进行去均值和resize,需要注意的是,如果想要设图像为448×448以外的size,需要修改model中全连接层的参数

boxes /= torch.Tensor([w,h,w,h]).expand_as(boxes)#坐标归一化处理
img = self.BGR2RGB(img)
img = self.subMean(img,self.mean)#减去均值
img = cv2.resize(img,(self.image_size,self.image_size))#调整图片尺寸为448*448

③训练前最重要的一步,将标签decode到7×7×30的向量(这里的7代表检测的密度,可以自己设定,会改变检测精度和性能;这里的30代表((4+1)*2+20),4表示检测框的中心点坐标x、y和宽高w、h,1表示是否是被检测物体的置信度,2表示回归两个矩形框,20表示20个类别)
在这里插入图片描述

三、训练
①输入数据:image(3×448×448),label(7×7×30)
②训练策略:没什么好写的,就是平常的训练步骤
③loss函数:YOLO给出的损失函数如下
在这里插入图片描述
此处给出的loss函数解释(转载https://www.jianshu.com/p/cad68ca85e27):
公式中
l i o b j l_i^{obj} liobj意思是网格i中存在对象。
l i j o b j l_{ij}^{obj} lijobj意思是网格i的第j个bounding box中存在对象。
l i j n o o b j l_{ij}^{noobj} lijnoobj意思是网格i的第j个bounding box中不存在对象。
总的来说,就是用网络输出与样本标签的各项内容的误差平方和作为一个样本的整体误差。
损失函数中的几个项是与输出的30维向量中的内容相对应的。
①对象分类的误差
公式第5行,注意 l i o b j l_i^{obj} liobj意味着存在对象的网格才计入误差。
②bounding box的位置误差
公式第1行和第2行。
a)都带有 l i j o b j l_{ij}^{obj} lijobj意味着只有"负责"(IOU比较大)预测的那个bounding box的数据才会计入误差。
b)第2行宽度和高度先取了平方根,因为如果直接取差值的话,大的对象对差值的敏感度较低,小的对象对差值的敏感度较高,所以取平方根可以降低这种敏感度的差异,使得较大的对象和较小的对象在尺寸误差上有相似的权重。
c)乘以 λ c o o r d λ_{coord} λcoord调节bounding box位置误差的权重(相对分类误差和置信度误差)。YOLO设置 λ c o o r d λ_{coord} λcoord=5,即调高位置误差的权重。
③bounding box的置信度误差
公式第3行和第4行。
a)第3行是存在对象的bounding box的置信度误差。带有 l i j o b j l_{ij}^{obj} lijobj意味着只有"负责"(IOU比较大)预测的那个bounding box的置信度才会计入误差。
b)第4行是不存在对象的bounding box的置信度误差。因为不存在对象的bounding box应该老老实实的说"我这里没有对象",也就是输出尽量低的置信度。如果它不恰当的输出较高的置信度,会与真正"负责"该对象预测的那个bounding box产生混淆。其实就像对象分类一样,正确的对象概率最好是1,所有其它对象的概率最好是0。
c)第4行会乘以 λ n o o b j λ_{noobj} λnoobj调节不存在对象的bounding box的置信度的权重(相对其它误差)。YOLO设置 λ n o o b j λ_{noobj} λnoobj=0.5,即调低不存在对象的bounding box的置信度误差的权重。

四、测试
网络的输出是7×7×30的向量,需要对其decode转换为人可以理解的数据。
①首先输入一张图像得到7×7×30的输出,对每个ceil的两个框共2×7×7个框进行筛选,取置信度(是否有物体)大于0.1的,并且一个ceil中有两个大于0.1框的取得分较大的那个。

pred = pred.data
pred = pred.squeeze(0)
contain1 = pred[:,:,4].unsqueeze(2)
contain2 = pred[:,:,9].unsqueeze(2)
contain = torch.cat((contain1,contain2),2)
mask1 = contain > 0.1
mask2 = (contain==contain.max()) 
mask = (mask1+mask2).gt(0)

②将筛选后的框取置信度(概率最大的类别的概率)大于0.1的,然后转换为(x1,y1,x2,y2)的格式。

for i in range(grid_num):
    for j in range(grid_num):
        for b in range(2):
            if mask[i,j,b] == 1:
                box = pred[i,j,b*5:b*5+4]
                contain_prob = torch.FloatTensor([pred[i,j,b*5+4]])
                xy = torch.FloatTensor([j,i])*cell_size 
                box[:2] = box[:2]*cell_size + xy 
                box_xy = torch.FloatTensor(box.size())
                box_xy[:2] = box[:2] - 0.5*box[2:]
                box_xy[2:] = box[:2] + 0.5*box[2:]
                max_prob,cls_index = torch.max(pred[i,j,10:],0)
                if float((contain_prob*max_prob)[0]) > 0.1:
                    boxes.append(box_xy.view(1,4))
                    cls_indexs.append(cls_index)
                    probs.append(contain_prob*max_prob)

③对筛选后的检测框进行nms,去除重复框,具体的nms算法网上有很多

keep = nms(boxes,probs)
return boxes[keep],cls_indexs[keep],probs[keep]

五、总结
至此,yolov1所有的细节就都介绍完了,从yolov1的训练和检测流程可以看出,yolov1没有选取候选框这一步,速度相较于r-cnn系列有非常大的提升,但是大家也可以发现,如果大家不修改网络,yolov1最多只可以检测出7*7=49个目标,对于密集目标检测远远不如r-cnn系列。而且由于下采样操作,细节信息丢失,对于检测框的精度也不如fast-rcnn。

  • 4
    点赞
  • 14
    收藏
    觉得还不错? 一键收藏
  • 4
    评论
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值