目录
0 前言
之前陆续已经实现了如何使用YOLOV3,YOLOV1进行目标检测,这两天想抽时间,学习一下官方的原源文件,补充一下理论知识。本文使用的源码来源于 https://github.com/hizhangp/yolo_tensorflow 。
1 YOLO_V1代码预览
由上面提供的链接,下载解压后,源文件的构成如下图,
一共有9个对象,其中三个文件夹,6个文件。
3个文件夹
test:存放的用于测试的图片
yolo:里面有三个文件
yolo_net.py:建立yolo网络(YOLONet)yolo_net.py定义了YOLONet类,该类包含了网络的的初始化(__init__()),建立网络(build_networks()),和loss函数(loss_layer())等方法。
config.py:路径和数据集参数设置
__init__.py:空文件
utils:里面有三个文件
pascal_voc.py:读取文件数据
__init__.py:空文件
timer.py:用于计时
6个文件
train.py:训练代码
test.py:测试代码
README:为帮助说明文件
LICENSE:证书
download_data.sh:应该是于下载数据有关的文件,待研究...
.gitignore:使用github提交本地代码时,忽略不必要提交的文件
2 train解析
train是用来训练网络的。
看眼train.py文件的整体框架,如下图:
2.1 先看一下导入的文件
import os # 操作系统操作模块
import argparse # python中的命令行解析模块
import datetime # 处理日期和时间的标准库
import tensorflow as tf # tensorflow框架
import yolo.config as cfg # yolo文件夹下的config.py配置文件
from yolo.yolo_net import YOLONet # yolo文件夹下的yolo_net.py网络构建文件
from utils.timer import Timer # utils文件夹下的timer.py计时文件
from utils.pascal_voc import pascal_voc # utils文件夹下的pascal_voc.py文件读取数据模块
slim = tf.contrib.slim #可以大大减少复杂网络的代码量
2.2 main函数开始
def main():
#用于解析命令行,显示帮助信息,直到......
parser = argparse.ArgumentParser()
parser.add_argument('--weights', default="YOLO_small.ckpt", type=str)
parser.add_argument('--data_dir', default="data", type=str)
parser.add_argument('--threshold', default=0.2, type=float)
parser.add_argument('--iou_threshold', default=0.5, type=float)
parser.add_argument('--gpu', default='', type=str)
args = parser.parse_args()
if args.gpu is not None:
cfg.GPU = args.gpu
if args.data_dir != cfg.DATA_PATH:
update_config_paths(args.data_dir, args.weights)
os.environ['CUDA_VISIBLE_DEVICES'] = cfg.GPU
#......这里
yolo = YOLONet() # 实例化YOLONet对象。YOLONet是在yolo_net.py文件中定义的类
pascal = pascal_voc('train') # 实例化pascal_voc对象。pascal_voc是在pascal_voc.py中定义的类
solver = Solver(yolo, pascal) # 实例化Solver对象。Solver是在本文件中定义的类
print('Start training ...')
solver.train() #调用solver对象的train方法,执行训练。
print('Done training.')
main()函数的前部分,使用argparser模块显示帮助信息,提示用户的。
运行效果:
由于本人的电脑显存较小,运行时报错了:ResourceExhaustedError
应该可以修改参数,完成,在yolo_v3训练时候,就是通过修改batchsize和epoch来实现。
main()函数的下半部分分别调用了使用YOLONet,pascal_voc Solver以及train()方法。下面依次进行介绍。
2.2 YOLONet
同样先看一下yolo_net.py文件的框架
两大任务:定义YOLONet()类,和激活函数leaky_relu()
关于leaky_relu()激活函数参看 激活函数
下面是大佬YOLONet()类上场。按照惯例,先一睹风采。
def __init__(): 网络初始化,包含了网络初始的参数
def build_network(): 建立网络
def calc_iou: 计算iou(Intersection over Union)
def loss_layer(): loss函数
看下面的注释
__init__()
def __init__(self, is_training=True):
##VOC 2012数据集类别名
self.classes = cfg.CLASSES # 类别
self.num_class = len(self.classes) #类别的数量,这里默认值20
self.image_size = cfg.IMAGE_SIZE # 网络输入图像的尺寸默认值448,448×448
self.cell_size = cfg.CELL_SIZE # cell尺寸,默认S = 7,将图像分为SxS的格子
self.boxes_per_cell = cfg.BOXES_PER_CELL # 每个geid cell负责的boxes,默认为2
##网络输出的大小 S*S*(B*5 + C) = 1470
self.output_size = (self.cell_size * self.cell_size) *\
(self.num_class + self.boxes_per_cell * 5) # 输出尺寸
#图片的缩放比例
self.scale = 1.0 * self.image_size / self.cell_size # 1×(448÷7)
self.boundary1 = self.cell_size * self.cell_size * self.num_class # 7×7×20
self.boundary2 = self.boundary1 +\
self.cell_size * self.cell_size * self.boxes_per_cell # 7×7×20 + 7×7×2
'''
boundery1和boundery2 作用是在输出中确定每种信息的长度(如类别,置信度等)。
其中 boundery1 指的是对于所有的 cell 的类别的预测的张量维度,所以是
self.cell_size * self.cell_size * self.num_class
boundery2 指的是在类别之后每个cell 所对应的 bounding boxes 的数量的总和,所以是
self.boundary1 + self.cell_size * self.cell_size * self.boxes_per_cell
'''
# 代价函数 权重
self.object_scale = cfg.OBJECT_SCALE # 默认值1.0
self.noobject_scale = cfg.NOOBJECT_SCALE # 默认值1.0
self.class_scale = cfg.CLASS_SCALE # 默认值2.0
self.coord_scale = cfg.COORD_SCALE # 默认值5.0
self.learning_rate = cfg.LEARNING_RATE # 学习率,默认值0.0001
self.batch_size = cfg.BATCH_SIZE #批次大小,默认值45
self.alpha = cfg.ALPHA #泄露修正线性激活函数 ,m默认系数0.1
# 偏置形状[7,7,2]
self.offset = np.transpose(np.reshape(np.array(
[np.arange(self.cell_size)] * self.cell_size * self.boxes_per_cell),
(self.boxes_per_cell, self.cell_size, self.cell_size)), (1, 2, 0))
#输入图片占位符 [NONE,image_size,image_size,3]
self.images = tf.placeholder(
tf.float32, [None, self.image_size, self.image_size, 3],
name='images') #placeholder()是在网络构建graph的时的占位此时没有把输入的数据传入模型,只分配必内存
#构建网络,获取YOLO网络的输出(不经过激活函数的输出),形状[None,1470]
self