keras_retinanet Object Detection: Training a Model on a Custom Image Dataset

This post walks through training a custom object-detection model with the keras_retinanet library: environment setup, dataset preparation, model training, and running detection. With the right parameter tweaks and data handling, the model trains successfully and can be applied to detection.

I have been learning keras_retinanet recently, so below are my notes on training a model with my own dataset.

The process breaks down into these steps:

  • Collect your own training images
  • Annotate the targets in each image
  • Generate .csv files containing the image file names, target bounding boxes, and target classes
  • Train the model (watch the parameter settings)
  • Convert the trained model
  • Run detection with the converted model

Let's go through them step by step:

Contents

1. Download the packages and set up the environment

2. Prepare the dataset

3. Train the model

4. Object detection


I already have vs2015, Anaconda, and Pycharm installed, so I won't cover that here.

My setup:

  • GTX 1060 3G
  • Win10 64
  • vs2015
  • Anaconda3 5.1.0
  • Pycharm

1. Download the packages and set up the environment

  • Download the keras-retinanet repository from GitHub
  • Make sure your environment has tensorflow, numpy, and keras
  • Switch into the repository directory and run
    pip install . --user

    Alternatively, run the code directly from the cloned repository, but first compile the Cython code with python setup.py build_ext --inplace.

  • If that still fails, copy the entire keras_retinanet directory into your environment's site-packages, e.g. D:\Anaconda3-5.0.1\envs\tf-gpu\Lib\site-packages

To prepare for testing: download the test model resnet50_coco_best_v2.1.0.h5 and place it in the snapshots directory.

You can then run examples/ResNet50RetinaNet.ipynb in jupyter notebook to verify the setup; you can also export the notebook as a .py file from jupyter notebook and run it in pycharm.

2. Prepare the dataset

  • The retinanet training pipeline expects data in the VOC2007 layout, so you need to arrange your own data in the VOC2007 format
  1. Create the three folders JPEGImages, Annotations, and ImageSets. JPEGImages holds the images you will train on; name them simply, e.g. 1.jpg, 2.jpg, and so on;
  2. Annotations holds the .xml files recording each target's position and class. You can generate them with the labelImg annotation tool from GitHub (remember to change its save path to the Annotations folder);
  3. The Main subfolder of ImageSets holds the .txt files listing the randomly sampled training, validation, and test splits. The following gen_main_txt.py generates them automatically
    import os
    import random

    trainval_percent = 0.8  # fraction of the whole dataset used for training + validation
    train_percent = 0.8  # fraction of the trainval set used for training
    xmlfilepath = 'Annotations'
    txtsavepath = 'ImageSets/Main'
    total_xml = os.listdir(xmlfilepath)

    num = len(total_xml)
    indices = range(num)  # don't shadow the built-in list()
    tv = int(num * trainval_percent)
    tr = int(tv * train_percent)
    trainval = random.sample(indices, tv)
    train = random.sample(trainval, tr)

    ftrainval = open('ImageSets/Main/trainval.txt', 'w')
    ftest = open('ImageSets/Main/test.txt', 'w')
    ftrain = open('ImageSets/Main/train.txt', 'w')
    fval = open('ImageSets/Main/val.txt', 'w')

    for i in indices:
        name = total_xml[i][:-4] + '\n'  # file name without the .xml extension
        if i in trainval:
            ftrainval.write(name)
            if i in train:
                ftrain.write(name)
            else:
                fval.write(name)
        else:
            ftest.write(name)

    ftrainval.close()
    ftrain.close()
    fval.close()
    ftest.close()
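With both ratios set to 0.8, the split works out to 64% train, 16% validation, and 20% test of the whole dataset. A quick standalone check of that arithmetic:

```python
num = 100                # suppose 100 annotated images
tv = int(num * 0.8)      # trainval size
tr = int(tv * 0.8)       # train size
val, test = tv - tr, num - tv
print(tv, tr, val, test)  # 80 64 16 20
```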

     

  4. Generate annotations.csv (file name, target position, target class) and classes.csv (class labels). The following gen_csv.py generates both automatically
    import csv
    import os
    import glob
    import sys

    class PascalVOC2CSV(object):
        def __init__(self, xml=[], ann_path='./annotations.csv', classes_path='./classes.csv'):
            '''
            :param xml: list of paths to all Pascal VOC .xml files
            :param ann_path: output path for annotations.csv
            :param classes_path: output path for classes.csv
            '''
            self.xml = xml
            self.ann_path = ann_path
            self.classes_path = classes_path
            self.label = []
            self.annotations = []

            self.data_transfer()
            self.write_file()

        def data_transfer(self):
            for num, xml_file in enumerate(self.xml):
                try:
                    # progress output
                    sys.stdout.write('\r>> Converting image %d/%d' % (
                        num + 1, len(self.xml)))
                    sys.stdout.flush()

                    with open(xml_file, 'r') as fp:
                        for p in fp:
                            if '<filename>' in p:
                                self.filen_ame = p.split('>')[1].split('<')[0]

                            if '<object>' in p:
                                # read the 9 lines of the object block as written by
                                # labelImg: name, pose, truncated, difficult, bndbox,
                                # xmin, ymin, xmax, ymax
                                d = [next(fp).split('>')[1].split('<')[0] for _ in range(9)]
                                self.supercategory = d[0]  # class name
                                if self.supercategory not in self.label:
                                    self.label.append(self.supercategory)

                                # bounding box corners
                                x1 = int(d[-4])
                                y1 = int(d[-3])
                                x2 = int(d[-2])
                                y2 = int(d[-1])

                                self.annotations.append(
                                    [os.path.join('JPEGImages', self.filen_ame), x1, y1, x2, y2, self.supercategory])
                except Exception:
                    continue

            sys.stdout.write('\n')
            sys.stdout.flush()

        def write_file(self):
            with open(self.ann_path, 'w', newline='') as fp:
                csv_writer = csv.writer(fp, dialect='excel')
                csv_writer.writerows(self.annotations)

            class_name = sorted(self.label)
            class_ = []
            for num, name in enumerate(class_name):
                class_.append([name, num])
            with open(self.classes_path, 'w', newline='') as fp:
                csv_writer = csv.writer(fp, dialect='excel')
                csv_writer.writerows(class_)

    if __name__ == "__main__":
        xml_file = glob.glob('./Annotations/*.xml')
        PascalVOC2CSV(xml_file)

     

  5. Use debug.py under keras-retinanet-master\keras_retinanet\bin to check that the dataset was generated correctly. It has to be run from the command line
    python D:/PyCharm/PycharmProjects/tf-gpu-env/project/keras-retinanet-master/keras_retinanet/bin/debug.py csv  D:/PyCharm/PycharmProjects/tf-gpu-env/project/keras-retinanet-master/examples/annotations.csv  D:/PyCharm/PycharmProjects/tf-gpu-env/project/keras-retinanet-master/examples/classes.csv

    If a window pops up showing your annotated images, the dataset is ready.

3. Train the model

  • Use train.py under keras-retinanet-master\keras_retinanet\bin to train the model. You need to change the relative imports inside the file to absolute ones:
    from keras_retinanet import layers  # noqa: F401
    from keras_retinanet import losses
    from keras_retinanet import models
    from keras_retinanet.callbacks import RedirectModel
    from keras_retinanet.callbacks.eval import Evaluate
    from keras_retinanet.models.retinanet import retinanet_bbox
    from keras_retinanet.preprocessing.csv_generator import CSVGenerator
    from keras_retinanet.preprocessing.kitti import KittiGenerator
    from keras_retinanet.preprocessing.open_images import OpenImagesGenerator
    from keras_retinanet.preprocessing.pascal_voc import PascalVocGenerator
    from keras_retinanet.utils.anchors import make_shapes_callback
    from keras_retinanet.utils.config import read_config_file, parse_anchor_parameters
    from keras_retinanet.utils.keras_version import check_keras_version
    from keras_retinanet.utils.model import freeze as freeze_model
    from keras_retinanet.utils.transform import random_transform_generator

     

  • Fine-tune the parameters to match your GPU. If you see an error like tensorflow.python.framework.errors_impl.ResourceExhaustedError: OOM when allocating tensor with shape[1,256,100,100], your GPU ran out of memory; you can work around it by lowering the batch size (--batch-size), shrinking --image-min-side and --image-max-side, or switching to a smaller network
    tensorflow.python.framework.errors_impl.ResourceExhaustedError: OOM when allocating tensor with shape[1,256,100,100] and type float on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc
             [[{
        {node training/Adam/gradients/classification_submodel/pyramid_classification_3/convolution_grad/Conv2DBackpropInput}} = Conv2DBackpropInput[T=DT_FLOAT, _class=["loc:@training/Adam/cond_85/Switch_2"], data_format="NCHW", dilations=[1, 1, 1, 1], padding="SAME", strides=[1, 1, 1, 1], use_cudnn_on_gpu=true, _device="/job:localhost/replica:0/task:0/device:GPU:0"](training/Adam/gradients/classification_submodel/pyramid_classification_3/convolution_grad/ShapeN, pyramid_classification_3/kernel/read, training/Adam/gradients/classification_submodel/pyramid_classification_3/Relu_grad/ReluGrad)]]
    Hint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info.

    Parameter tuning:

  • Change the workers argument (multi-worker loading kept erroring for me, for reasons I don't understand, so leave it disabled for now)

  • Run train.py from the command line
    python D:/PyCharm/PycharmProjects/tf-gpu-env/project/keras-retinanet-master/keras_retinanet/bin/train.py csv  D:/PyCharm/PycharmProjects/tf-gpu-env/project/keras-retinanet-master/examples/annotations.csv  D:/PyCharm/PycharmProjects/tf-gpu-env/project/keras-retinanet-master/examples/classes.csv

    The model is created and the layer summary (truncated here) prints as follows

Creating model, this may take a second...
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
input_1 (InputLayer)            (None, None, None, 3 0
__________________________________________________________________________________________________
padding_conv1 (ZeroPadding2D)   (None, None, None, 3 0           input_1[0][0]
__________________________________________________________________________________________________
conv1 (Conv2D)                  (None, None, None, 6 9408        padding_conv1[0][0]
__________________________________________________________________________________________________
bn_conv1 (BatchNormalization)   (None, None, None, 6 256         conv1[0][0]
__________________________________________________________________________________________________
conv1_relu (Activation)         (None, None, None, 6 0           bn_conv1[0][0]
__________________________________________________________________________________________________
pool1 (MaxPooling2D)            (None, None, None, 6 0           conv1_relu[0][0]
__________________________________________________________________________________________________
res2a_branch2a (Conv2D)         (None, None, None, 6 4096        pool1[0][0]
__________________________________________________________________________________________________
bn2a_branch2a (BatchNormalizati (None, None, None, 6 256         res2a_branch2a[0][0]
__________________________________________________________________________________________________
res2a_branch2a_relu (Activation (None, None, None, 6 0           bn2a_branch2a[0][0]
__________________________________________________________________________________________________
padding2a_branch2b (ZeroPadding (None, None, None, 6 0           res2a_branch2a_relu[0][0]
__________________________________________________________________________________________________
res2a_branch2b (Conv2D)         (None, None, None, 6 36864       padding2a_branch2b[0][0]
__________________________________________________________________________________________________
bn2a_branch2b (BatchNormalizati (None, None, None, 6 256         res2a_branch2b[0][0]
__________________________________________________________________________________________________
res2a_branch2b_relu (Activation (None, None, None, 6 0           bn2a_branch2b[0][0]
____________________________________________________________________________________________