YoloV2网络标签代码粗解

https://github.com/experiencor/keras-yolo2 代码地址

        近几年,对于目标定位来讲,使用one-stage策略的主流算法有SSD,Yolo V1 ,Yolo V2,Yolo V3。一年前剖析SSD源码,并使用于自己的训练样本中,效果还OK。闲来无事,分析下 Yolo V2,Yolo V3的关键代码,并进行记录。

        从研一接触深度学习以来,这两年也一直做一些视觉方面的实验,主要是OCR方面的东西。之前一直没有去记录自己的想法有点遗憾。现在觉得记录一些东西,可以看到自己一步一步的成长。对于即将毕业的研三老狗,先解析下Yolo V2,Yolo V3的代码把,因为之前一种用的SSD,还有其他文字定位的网络。话不多说,对于一个深度学习的算法来讲,里面关键的知识是网络的输入、输出和损失函数的定义,这些关键信息可以整理明白的话,我们可以使用这些关键信息去处理我们自己想要完成的任务。因此,我们本篇先分析下Yolo V2的输入和输出。

        代码在上面给出了Github地址,是目前start比较多的。基于我的认知进行简单的注释,工作比较low,请见谅。其他具体的细节部分可以阅读论文或参考其他博客。

        下面的代码是preprocessing.py下面的主要部分。   

def __getitem__(self, idx):
        l_bound = idx*self.config['BATCH_SIZE']     
        r_bound = (idx+1)*self.config['BATCH_SIZE']   ##########定义batch_size的左右边界。

        if r_bound > len(self.images):
            r_bound = len(self.images)
            l_bound = r_bound - self.config['BATCH_SIZE']

        instance_count = 0

        x_batch = np.zeros((r_bound - l_bound, self.config['IMAGE_H'], self.config['IMAGE_W'], 3))                         # input images

         #############网络输入的部分   大小为   batch_size*图像高*图像宽*图像通道数
        b_batch = np.zeros((r_bound - l_bound, 1     , 1     , 1    ,  self.config['TRUE_BOX_BUFFER'], 4))   # list of self.config['TRUE_self.config['BOX']_BUFFER'] GT boxes

#############网络的输入部分   大小为 batch_size*1*1*1*自定义目标数量*4       'TRUE_BOX_BUFFER'代表的是每一幅图片中物体包含最多的目标    4 指的是 目标的  中心坐标(x,y)以及宽和高(w,h)  这个输入主要用于计算损失函数 下一部分进行解释
        y_batch = np.zeros((r_bound - l_bound, self.config['GRID_H'],  self.config['GRID_W'], self.config['BOX'], 4+1+len(self.config['LABELS'])))                # desired network output

#############上面是网络的两个输入,现在是网络的输出 大小为batch_size*特征图高*特征图宽*自定义预测目标数量*(4+1+预测目标的种类)  特征图高*特征图宽这里的值为13,因为训练原图大小为416,进行32倍缩放,即416/32=13     'BOX'其实相当于一个特征图中的一个点预测预先设置不同宽高anchor。类似于ssd中的aspect的设置,此处的值为5 ===[0.57273, 0.677385, 1.87446, 2.06253, 3.33843, 5.47434, 7.88282, 3.52778, 9.77052, 9.16828] 该值是通过k-means聚类出来的5组相对最优框。不同的数据集该值是不同的。 4 指的是 目标的  中心坐标(x,y)以及宽和高(w,h) ,1指的是目标的置信度,'LABELS'指的是目标的种类,此处并不包含背景,而ssd中是包含背景但是没有置信度,这是两个的区别。

##################对图片和包含的目标进行处理

        for train_instance in self.images[l_bound:r_bound]:
            # augment input image and fix object's position and size
            img, all_objs = self.aug_image(train_instance, jitter=self.jitter)
            ############### 该函数就是为了对图片进行增强,例如翻转,旋转,缩放等等,可以自己定义。注:图片相对变化时,标注的坐标点也需要相对变换。
            # construct output from object's x, y, w, h
            true_box_index = 0
            #############true_box_index 图片中包含目标个数的索引值,初始为0,接下来遍历图片中的需要定位的目标的坐标
            for obj in all_objs:
                if obj['xmax'] > obj['xmin'] and obj['ymax'] > obj['ymin'] and obj['name'] in self.config['LABELS']:
                    center_x = .5*(obj['xmin'] + obj['xmax'])
                    center_x = center_x / (float(self.config['IMAGE_W']) / self.config['GRID_W'])
                    center_y = .5*(obj['ymin'] + obj['ymax'])
                    center_y = center_y / (float(self.config['IMAGE_H']) / self.config['GRID_H'])

##############计算实际目标坐标相对于缩放一定倍数(32)后的目标中心(x,y)位置的值

                    grid_x = int(np.floor(center_x))
                    grid_y = int(np.floor(center_y))

#############进行向上取整

                    if grid_x < self.config['GRID_W'] and grid_y < self.config['GRID_H']:
                        obj_indx  = self.config['LABELS'].index(obj['name'])
                        
                        center_w = (obj['xmax'] - obj['xmin']) / (float(self.config['IMAGE_W']) / self.config['GRID_W']) # unit: grid cell
                        center_h = (obj['ymax'] - obj['ymin']) / (float(self.config['IMAGE_H']) / self.config['GRID_H']) # unit: grid cell
                        #############同理计算缩放后的目标宽和高的值(w,h)
                        box = [center_x, center_y, center_w, center_h]

                        # find the anchor that best predicts this box
                        best_anchor = -1
                        max_iou     = -1
                        ###############寻找根据先验知识预先设置的5个框,与现在实际映射后的IOU最大的一个框的索引值和其IOU的值
                        shifted_box = BoundBox(0, 
                                               0,
                                               center_w,                                                
                                               center_h)
                        
                        for i in range(len(self.anchors)):
                            anchor = self.anchors[i]
                            iou    = bbox_iou(shifted_box, anchor)
                            
                            if max_iou < iou:
                                best_anchor = i
                                max_iou     = iou
                         ################       best_anchor  记录最佳的1个框的索引,max_iou     记录最佳IOU值
                        # assign ground truth x, y, w, h, confidence and class probs to y_batch
                        y_batch[instance_count, grid_y, grid_x, best_anchor, 0:4] = box
                        y_batch[instance_count, grid_y, grid_x, best_anchor, 4  ] = 1.
                        y_batch[instance_count, grid_y, grid_x, best_anchor, 5+obj_indx] = 1
                        ##########记录实际坐标值   0-4维存放box信息,置信度和目标的种类信息。
                        # assign the true box to b_batch
                        b_batch[instance_count, 0, 0, 0, true_box_index] = box
                        ########记录图片中包含目标个数信息的box信息。
                        true_box_index += 1
                        true_box_index = true_box_index % self.config['TRUE_BOX_BUFFER']
                            
            # assign input image to x_batch

###############根据使用不同的网络需要对图片进行norm化
            if self.norm != None: 
                x_batch[instance_count] = self.norm(img)
            else:
                # plot image and bounding boxes for sanity check
                for obj in all_objs:
                    if obj['xmax'] > obj['xmin'] and obj['ymax'] > obj['ymin']:
                        cv2.rectangle(img[:,:,::-1], (obj['xmin'],obj['ymin']), (obj['xmax'],obj['ymax']), (255,0,0), 3)
                        cv2.putText(img[:,:,::-1], obj['name'], 
                                    (obj['xmin']+2, obj['ymin']+12), 
                                    0, 1.2e-3 * img.shape[0], 
                                    (0,255,0), 2)
                        
                x_batch[instance_count] = img

            # increase instance counter in current batch
            instance_count += 1  

        #print(' new batch created', idx)

############返回网络所需要的batch_size大小的输入和输出数据

        return [x_batch, b_batch], y_batch

        看到Yolo V2训练的label格式,大致就能猜想到使用的损失函数,大家可以尝试不同的网络结构去训练自己的样本,适合自己的才是最好的。原始网络为了增加识别小目标的,是使用中间层的特征拼接最后一层的特征,这点类似于SSD的多尺度。下一章介绍下损失函数,有了这两部分,就可以用于自己的数据实验。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值