基于TensorFlow2.0的YOLOV2训练过程

目录

 1. 环境配置

1.1 Anaconda安装

1.2 Pycharm安装

1.3 TensorFlow安装

2. 训练数据集准备

2.1 数据集标注

3. 训练数据集预处理

3.1 解析标签文件XML

3.2 读取图片

4. 真实标签格式处理

4.1 单张图片 

4.2 批量图片

5. 模型搭建与权重初始化

5.1 模型搭建

5.2 权重初始化

6. 损失计算

6.1 制作网格坐标

6.2 坐标损失计算

6.3 类别损失计算

6.4 置信度计算

7. 模型训练与保存

8. 模型验证


运行环境:

Python:3.6

TensorFlow: 2.0.0+

cuda: 10.0

cundnn: 7.4

Pycharm: 发行版

 1. 环境配置

1.1 Anaconda安装

我使用的是Windows系统,当然,使用Ubuntu也可以,没有什么区别。

下载Anaconda3,下载链接:https://pan.baidu.com/s/1xzrb7kqigl5SYigVO2NdWw,提取码:41tg 

将Anaconda3下载完成后,然后安装。

1.2 Pycharm安装

下载Pycharm, 下载链接:https://pan.baidu.com/s/1SOhs72JK9YY6GAFImrwdBQ,提取码:bqsn

将Pycharm下载完成后,然后安装

1.3 TensorFlow安装

1. 创建一个Python虚拟环境,使用Anaconda Prompt 或者 Anaconda Navigator都可以,我使用的是Prompt, ubuntu系统可以使用终端或者Navigator。

conda create -n Tensorflow-GPU python=3.6

环境的名字可以任意选择。

2. 激活环境,在该环境中安装TensorFlow2.0,我这里介绍一种简单的方法。

conda install tensorflow-gpu==2.0.0 #gpu版本

# conda install tensorflow==2.0.0 #cpu版本

 通过该命令会将TensorFlow-gpu版本自动安装成功,包含配套的cuda, cudnn。在ubuntu上一样的命令,如果安装失败,一般都是因为网速的问题,可以考虑将conda的源换为国内源,这里就不再多赘述,CSDN中有很多博客介绍。

3. 打开Pycharm配置环境即可。

2. 训练数据集准备

目标检测数据集一般是VOC格式的,YOLO与SSD都是这种格式。

2.1 数据集标注

1. 首先将采集好的原图,全部resize成网络输入的尺寸,比如YOLOV2的输入尺寸是512X512。

# -*- coding: utf-8 -*-
import cv2
import os

def rebuild(path_src, path_dst, width, height):
    """
    :param path_src: 原图相对地址
    :param path_dst: 保存图相对地址
    :return: None
    """
    i = 1
    image_names = os.listdir(path_src)
    for image in image_names:
        if image.endswith('.jpg') or image.endswith('.png'):
            img_path = path_src + image
            save_path = path_dst + image
            img = cv2.imread(img_path)
            resize_img = cv2.resize(img, (width, height))
            cv2.imwrite(save_path, resize_img)
            print("修改第 " + str(i), " 张图片:", save_path)
            i = i + 1

if __name__ == "__main__":
    # 原图相对地址,也可以使用绝对地址
    path_src = "pikachu/"
    # 保存图相对地址,也可以使用绝对地址
    path_dst = "pikachu_new/"
    width = 512
    heght = 512
    rebuild(path_src, path_dst, width, heght)

 2. 使用labelImg进行目标标注,使用别的标注工具也可以

labelImg安装方法1:直接下载软件,然后放在桌面双击打开即可,不需要安装

链接:https://pan.baidu.com/s/1_wdd_tChBCrfcicKC-Nxgg 提取码:tsz7

labelImg安装方法2:去github下载源码编译, github链接:https://github.com/tzutalin/labelImg

3. 训练数据集预处理

3.1 解析标签文件XML

请下载文件://download.csdn.net/download/qq_37116150/12289197

该文件包含完整代码

每张图片的标签信息全部保存在.xml(使用labelImg标注图片生成的文件)文件中,标签文件中包含原图路径,原图名,目标位置信息(左上角坐标,右下角坐标,够成一个矩形框),类别名,我们需要的是原图路径, 目标位置信息以及类别名,所有我们需要将这些信息从xml标签文件中提取出来。

xml_parse.py, 可将该文件直接下载下来,由于YOLO整个项目比较大,代码量比较多,所以分成几个文件,一起编写。

# -*- coding: utf-8 -*-
import os, glob
import numpy as np
import xml.etree.ElementTree as ET

"""

该文件主要用于解析xml文件,同时返回原图片的路径与标签中目标的位置信息以及类别信息

"""
def paras_annotation(img_dir, ann_dir, labels):
    """
    :param img_dir: image path
    :param ann_dir: annotation xml file path
    :param labels: ("class1", "class2",...,), 背景默认为0
    :function: paras annotation info from xml file
    :return:
    """
    imgs_info = []  #存储所有图片信息的容器列表
    max_boxes = 0   #计算所有图片中,目标在一张图片中所可能出现的最大数量
    # for each annotation xml file
    for ann in os.listdir(ann_dir):  # 遍历文件夹中所有的xml文件, 返回值是xml的地址
        tree = ET.parse(os.path.join(ann_dir, ann))  #使用xml内置函数读取xml文件,并返回一个可读取节点的句柄

        img_info = dict()  # 为每一个标签xml文件创建一个内容存放容器字典
        boxes_counter = 0  # 计算该标签文件中所含有的目标数量
        # 由于每张标签中,目标存在数量可能大于1, 所有将object内容格式设置为列表,以存放多个object
        img_info['object'] = []
        for elem in tree.iter(): # 遍历xml文件中所有的节点
            if 'filename' in elem.tag:  # 读取文件名,将文件绝对路径存储在字典中
                img_info['filename'] = os.path.join(img_dir, elem.text)
            # 读取标签中目标的宽,高, 通道默认为3不进行读取
            if 'width' in elem.tag:
                img_info['width'] = int(elem.text)
                # assert img_info['width'] == 512  #用于断言图片的宽高为512 512
            if 'height' in elem.tag:
                img_info['height'] = int(elem.text)
                # assert img_info['height'] == 512

            if 'object' in elem.tag or 'part' in elem.tag:  # 读取目标框的信息
                # 目标框信息存储方式:x1-y1-x2-y2-label
                object_info = [0, 0, 0, 0, 0] # 创建存储目标框信息的容器列表
                boxes_counter += 1
                for attr in list(elem):  # 循环读取子节点
                    if 'name' in attr.tag:  # 目标名
                        label = labels.index(attr.text) + 1 # 返回索引值 并加1, 因为背景为0
                        object_info[4] = label
                    if 'bndbox' in attr.tag:  # bndbox的信息
                        for pos in list(attr):
                            if 'xmin' in pos.tag:
                                object_info[0] = int(pos.text)
                            if 'ymin' in pos.tag:
                                object_info[1] = int(pos.text)
                            if 'xmax' in pos.tag:
                                object_info[2] = int(pos.text)
                            if 'ymax' in pos.tag:
                                object_info[3] = int(pos.text)
                # object shape: [n, 5],是一个列表,但包含n个子列表,每个子列表有5个内容
                img_info['object'].append(object_info)

        imgs_info.append(img_info)  # filename, w/h/box_info
        # (N,5)=(max_objects_num, 5)
        if boxes_counter > max_boxes:
            max_boxes = boxes_counter
    # the maximum boxes number is max_boxes
    # 将读取的object信息转化为一个矩阵形式:[b, max_objects_num, 5]
    boxes = np.zeros([len(imgs_info), max_boxes, 5])
    # print(boxes.shape)
    imgs = []  # filename list
    for i, img_info in enumerate(imgs_info):
        # [N,5]
        img_boxes = np.array(img_info['object']) # img_boxes.shape[N, 5]
        # overwrite the N boxes info
        boxes[i, :img_boxes.shape[0]] = img_boxes

        imgs.append(img_info['filename'])  # 文件名

        # print(img_info['filename'], boxes[i,:5])
    # imgs: list of image path
    # boxes: [b,40,5]
    return imgs, boxes


# 测试代码
# if __name__ == "__main__":
#     img_path = "data\\val\\image"  #图片路径
#     annotation_path = "data\\val\\annotation" # 标签路径
#     label = ("sugarbeet", "weed")  # 自定义的标签名字,背景不写,默认为0
#
#     img, box = paras_annotation(img_path, annotation_path, label)
#     print(img[0])
#     print(box.shape)
#     print(box[0])

 paras_annotation返回值imgs, boxes, 其中imgs是个列表,它包含了每张图片的路径,boxes是一个三维矩阵,它包含了每张图片的所有目标位置与类别信息,所以它的shape是[b, max_boxes, 5],b: 图片数量,max_boxes: 所有图片中最大目标数,比如图片A有3个目标,图片B有4个目标,图片C有10个目标,则最大目标数就是10;5: x_min, y_min, x_max, y_max, label(在xml中就是name)。

之所以有max_boxes这个参数设置,是为了将所有的标签文件的信息都放在一个矩阵变量中。因为每张图片的目标数必然是不一样的,如果不设置max_boxes这个参数,就无法将所有的标签文件信息合在一个矩阵变量中。如果一个图片的目标数不够max_boxes怎么办,例如图片A有3个目标,max_boxes是10,则假设图片A有10个目标,只是将后7个目标的数据全部置为0,前三个目标的数据赋值于它原本的数值,这也是开始为什么用np.zeros()初始化boxes。

  • 14
    点赞
  • 36
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 26
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 26
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

然雪

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值