基于TensorFlow2.0的YOLOV2训练过程

本博客详细介绍了基于TensorFlow2.0实现YOLOv2目标检测的过程,包括环境配置、数据集准备、模型搭建、损失函数计算等关键步骤,适合深度学习初学者及进阶者参考。
摘要由CSDN通过智能技术生成

目录

 1. 环境配置

1.1 Anaconda安装

1.2 Pycharm安装

1.3 TensorFlow安装

2. 训练数据集准备

2.1 数据集标注

3. 训练数据集预处理

3.1 解析标签文件XML

3.2 读取图片

4. 真实标签格式处理

4.1 单张图片 

4.2 批量图片

5. 模型搭建与权重初始化

5.1 模型搭建

5.2 权重初始化

6. 损失计算

6.1 制作网格坐标

6.2 坐标损失计算

6.3 类别损失计算

6.4 置信度计算

7. 模型训练与保存

8. 模型验证


运行环境:

Python:3.6

TensorFlow: 2.0.0+

cuda: 10.0

cundnn: 7.4

Pycharm: 发行版

 1. 环境配置

1.1 Anaconda安装

我使用的是Windows系统,当然,使用Ubuntu也可以,没有什么区别。

下载Anaconda3,下载链接:https://pan.baidu.com/s/1xzrb7kqigl5SYigVO2NdWw,提取码:41tg 

将Anaconda3下载完成后,然后安装。

1.2 Pycharm安装

下载Pycharm, 下载链接:https://pan.baidu.com/s/1SOhs72JK9YY6GAFImrwdBQ,提取码:bqsn

将Pycharm下载完成后,然后安装

1.3 TensorFlow安装

1. 创建一个Python虚拟环境,使用Anaconda Prompt 或者 Anaconda Navigator都可以,我使用的是Prompt, ubuntu系统可以使用终端或者Navigator。

conda create -n Tensorflow-GPU python=3.6

环境的名字可以任意选择。

2. 激活环境,在该环境中安装TensorFlow2.0,我这里介绍一种简单的方法。

conda install tensorflow-gpu==2.0.0 #gpu版本

# conda install tensorflow==2.0.0 #cpu版本

 通过该命令会将TensorFlow-gpu版本自动安装成功,包含配套的cuda, cudnn。在ubuntu上一样的命令,如果安装失败,一般都是因为网速的问题,可以考虑将conda的源换为国内源,这里就不再多赘述,CSDN中有很多博客介绍。

3. 打开Pycharm配置环境即可。

2. 训练数据集准备

目标检测数据集一般是VOC格式的,YOLO与SSD都是这种格式。

2.1 数据集标注

1. 首先将采集好的原图,全部resize成网络输入的尺寸,比如YOLOV2的输入尺寸是512X512。

# -*- coding: utf-8 -*-
import cv2
import os

def rebuild(path_src, path_dst, width, height):
    """
    :param path_src: 原图相对地址
    :param path_dst: 保存图相对地址
    :return: None
    """
    i = 1
    image_names = os.listdir(path_src)
    for image in image_names:
        if image.endswith('.jpg') or image.endswith('.png'):
            img_path = path_src + image
            save_path = path_dst + image
            img = cv2.imread(img_path)
            resize_img = cv2.resize(img, (width, height))
            cv2.imwrite(save_path, resize_img)
            print("修改第 " + str(i), " 张图片:", save_path)
            i = i + 1

if __name__ == "__main__":
    # 原图相对地址,也可以使用绝对地址
    path_src = "pikachu/"
    # 保存图相对地址,也可以使用绝对地址
    path_dst = "pikachu_new/"
    width = 512
    heght = 512
    rebuild(path_src, path_dst, width, heght)

 2. 使用labelImg进行目标标注,使用别的标注工具也可以

labelImg安装方法1:直接下载软件,然后放在桌面双击打开即可,不需要安装

链接:https://pan.baidu.com/s/1_wdd_tChBCrfcicKC-Nxgg 提取码:tsz7

labelImg安装方法2:去github下载源码编译, github链接:https://github.com/tzutalin/labelImg

3. 训练数据集预处理

3.1 解析标签文件XML

请下载文件://download.csdn.net/download/qq_37116150/12289197

该文件包含完整代码

每张图片的标签信息全部保存在.xml(使用labelImg标注图片生成的文件)文件中,标签文件中包含原图路径,原图名,目标位置信息(左上角坐标,右下角坐标,够成一个矩形框),类别名,我们需要的是原图路径, 目标位置信息以及类别名,所有我们需要将这些信息从xml标签文件中提取出来。

xml_parse.py, 可将该文件直接下载下来,由于YOLO整个项目比较大,代码量比较多,所以分成几个文件,一起编写。

# -*- coding: utf-8 -*-
import os, glob
import numpy as np
import xml.etree.ElementTree as ET

"""

该文件主要用于解析xml文件,同时返回原图片的路径与标签中目标的位置信息以及类别信息

"""
def paras_annotation(img_dir, ann_dir, labels):
    """
    :param img_dir: image path
    :param ann_dir: annotation xml file path
    :param labels: ("class1", "class2",...,), 背景默认为0
    :function: paras annotation info from xml file
    :return:
    """
    imgs_info = []  #存储所有图片信息的容器列表
    max_boxes = 0   #计算所有图片中,目标在一张图片中所可能出现的最大数量
    # for each annotation xml file
    for ann in os.listdir(ann_dir):  # 遍历文件夹中所有的xml文件, 返回值是xml的地址
        tree = ET.parse(os.path.join(ann_dir, ann))  #使用xml内置函数读取xml文件,并返回一个可读取节点的句柄

        img_info = dict()  # 为每一个标签xml文件创建一个内容存放容器字典
        boxes_counter = 0  # 计算该标签文件中所含有的目标数量
        # 由于每张标签中,目标存在数量可能大于1, 所有将object内容格式设置为列表,以存放多个object
        img_info['object'] = []
        for elem in tree.iter(): # 遍历xml文件中所有的节点
            if 'filename' in elem.tag:  # 读取文件名,将文件绝对路径存储在字典中
                img_info['filename'] = os.path.join(img_dir, elem.text)
            # 读取标签中目标的宽,高, 通道默认为3不进行读取
            if 'width' in elem.tag:
                img_info['width'] = int(elem.text)
                # assert img_info['width'] == 512  #用于断言图片的宽高为512 512
            if 'height' in elem.tag:
                img_info['height'] = int(elem.text)
                # assert img_info['height'] == 512

            if 'object' in elem.tag or 'part' in elem.tag:  # 读取目标框的信息
                # 目标框信息存储方式:x1-y1-x2-y2-label
                object_info = [0, 0, 0, 0, 0] # 创建存储目标框信息的容器列表
                boxes_counter += 1
                for attr in list(elem):  # 循环读取子节点
                    if 'name' in attr.tag:  # 目标名
                        label = labels.index(attr.text) + 1 # 返回索引值 并加1, 因为背景为0
                        object_info[4] = label
                    if 'bndbox' in attr.tag:  # bndbox的信息
                        for pos in list(attr):
                            if 'xmin' in pos.tag:
                                object_info[0] = int(pos.text)
                            if 'ymin' in pos.tag:
                                object_info[1] = int(pos.text)
                            if 'xmax' in pos.tag:
                                object_info[2] = int(pos.text)
                            if 'ymax' in pos.tag:
                                object_info[3] = int(pos.text)
                # object shape: [n, 5],是一个列表,但包含n个子列表,每个子列表有5个内容
                img_info['object'].append(object_info)

        imgs_info.append(img_info)  # filename, w/h/box_info
        # (N,5)=(max_objects_num, 5)
        if boxes_counter > max_boxes:
            max_boxes = boxes_counter
    # the maximum boxes number is max_boxes
    # 将读取的object信息转化为一个矩阵形式:[b, max_objects_num, 5]
    boxes = np.zeros([len(imgs_info), max_boxes, 5])
    # print(boxes.shape)
    imgs = []  # filename list
    for i, img_info in enumerate(imgs_info):
        # [N,5]
        img_boxes = np.array(img_info['object']) # img_boxes.shape[N, 5]
        # overwrite the N boxes info
        boxes[i, :img_boxes.shape[0]] = img_boxes

        imgs.append(img_info['filename'])  # 文件名

        # print(img_info['filename'], boxes[i,:5])
    # imgs: list of image path
    # boxes: [b,40,5]
    return imgs, boxes


# 测试代码
# if __name__ == "__main__":
#     img_path = "data\\val\\image"  #图片路径
#     annotation_path = "data\\val\\annotation" # 标签路径
#     label = ("sugarbeet", "weed")  # 自定义的标签名字,背景不写,默认为0
#
#     img, box = paras_annotation(img_path, annotation_path, label)
#     print(img[0])
#     print(box.shape)
#     print(box[0])

 paras_annotation返回值imgs, boxes, 其中imgs是个列表,它包含了每张图片的路径,boxes是一个三维矩阵,它包含了每张图片的所有目标位置与类别信息,所以它的shape是[b, max_boxes, 5],b: 图片数量,max_boxes: 所有图片中最大目标数,比如图片A有3个目标,图片B有4个目标,图片C有10个目标,则最大目标数就是10;5: x_min, y_min, x_max, y_max, label(在xml中就是name)。

之所以有max_boxes这个参数设置,是为了将所有的标签文件的信息都放在一个矩阵变量中。因为每张图片的目标数必然是不一样的,如果不设置max_boxes这个参数,就无法将所有的标签文件信息合在一个矩阵变量中。如果一个图片的目标数不够max_boxes怎么办,例如图片A有3个目标,max_boxes是10,则假设图片A有10个目标,只是将后7个目标的数据全部置为0,前三个目标的数据赋值于它原本的数值,这也是开始为什么用np.zeros()初始化boxes。

3.2 读取图片

请下载文件://download.csdn.net/download/qq_37116150/12289208

该文件包含完整代码

我们训练需要的是图片的内容信息,不是路径,所以我们需要通过图片路径来读取图片,以获得图片信息,通过3.1可以获得所有训练图片的路径。

def preprocess(img, img_boxes):
    # img: string
    # img_boxes: [40,5]
    x = tf.io.read_file(img)
    x = tf.image.decode_png(x, channels=3)
    x = tf.image.convert_image_dtype(x, tf.float32) # 将数据转化为 =>[0~ 1]

    return x, img_boxes

使用tensorflow自带的读取图片函数tf.io.read_file来读取图片,不用使用for循环一个一个的读取图片,然后使用tf.image.decode_png将图片信息解码出来,如果你的训练图片是jpg,则使用tf.image.decode_jpeg来解码。tf.image.convert_image_dtype(x, tf.float32)可将数据直接归一化并将数据格式转化为tf.float32格式。

为了更加方便训练,我们需要构建一个tensorflow队列,将解码出来的图片数据与标签数据一起加载进队列中,而且通过这种方式,也可以使图片数据与标签数据一一对应,不会出现图片与标签对照絮乱的情况。

def get_datasets(img_dir, ann_dir,label,batch_size=1):
    imgs, boxes = paras_annotation(img_dir, ann_dir, label)
    db = tf.data.Dataset.from_tensor_slices((imgs, boxes))
    db = db.shuffle(1000).map(preprocess).batch(batch_size=batch_size).repeat()
    # db = db.map(preprocess).batch(batch_size=batch_size).repeat()
    return db

通过该函数也可以动态的调节训练数据集批量。

最后就是做数据增强,由于代码较多,就不再赘述,可下载文件观看。

通过3.1,3.2,我们就得到了用于训练的数据队列,该队列中包含图片数据,真实标签数据。

4. 真实标签格式处理

请下载文件://download.csdn.net/download/qq_37116150/12289213

该文件包含完整代码

4.1 单张图片 

 到了这一步,训练数据预处理算是完成了一小半,后面则是更加重要的训练数据预处理。首先,我们要明白一个问题,目标检测和目标分类是不一样的。目标分类的输出是一个二维张量[batch, num_classes],目标分类的真实标签通过热编码后也是一个二维张量,所有不需要多做处理,只做一个one-hot就可以啦。而目标检测的输出并不是一个二维张量,比如YOLOV2输出的就是五维张量                    [batch, 16, 16, 5, 25]。而我们的标签shape则是[batch, max_boxes, 5],明显真实标签shape与网络预测输出shape不一致,无法做比较,损失函数就不能完成,为了完成损失函数或者说是真实标签与网络预测输出作比较,需要修改真实标签的形状。在修改真实标签shape之前,需要了解YOLOV2的损失函数是由几部分构成的。

YOLOV2损失函数包含三部分:

  1. 坐标损失: x,y,w,h
  2. 类别损失: class,根据自己的标签设定

  3. 置信度损失: confidence, anchors与真实框的IOU

针对损失函数,需要预先准备四个变量,分别是真实标签掩码,五维张量的真实标签,转换格式的三维张量真实标签,只包含类别的五维张量。请看具体代码:

def process_true_boxes(gt_boxes, anchors):
    """
    计算一张图片的真实标签信息
    :param gt_boxes:
    :param anchors:YOLO的预设框anchors
    :return:
    """
    # gt_boxes: [40,5] 一张真实标签的位置坐标信息
    # 512//16=32
    # 计算网络模型从输入到输出的缩小比例
    scale = IMGSZ // GRIDSZ  # IMGSZ:图片尺寸512,GRIDSZ:输出尺寸16
    # [5,2] 将anchors转化为矩阵形式,一行代表一个anchors
    anchors = np.array(anchors).reshape((5, 2))

    # mask for object
    # 用来判断该方格位置的anchors有没有目标,每个方格有5个anchors
    detector_mask = np.zeros([GRIDSZ, GRIDSZ, 5, 1])
    # x-y-w-h-l
    # 在输出方格的尺寸上[16, 16, 5]制作真实标签, 用于和预测输出值做比较,计算损失值
    matching_gt_box = np.zeros([GRIDSZ, GRIDSZ, 5, 5])
    # [40,5] x1-y1-x2-y2-l => x-y-w-h-l
    # 制作一个numpy变量,用于存储一张图片真实标签转换格式后的数据
    # 将左上角与右下角坐标转化为中心坐标与宽高的形式
    # [x_min, y_min, x_max, y_max] => [x_center, y_center, w, h]
    gt_boxes_grid = np.zeros(gt_boxes.shape)
    # DB: tensor => numpy 方便计算
    gt_boxes = gt_boxes.numpy()

    f
评论 26
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

然雪

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值