Tensorflow2.0 物体识别(一)训练模型数据准备

最近研究了一下人工智能图像识别方面的知识,使用的是鼎鼎大名的tensorflow工具,同时为了避免重复造轮子,使用了谷歌已经实现的的object detection api来实现图像识别。下载地址如下:https://github.com/tensorflow/models

关于环境的安装,其实网上有很多教程,但是大部分都是基于tensorflow1.0的,基于tensorflow2.0的并不多,不过总体安装也并不算复杂,有些坑,但是不多,本章我们不做详细介绍,本章的主要目的是为了能训练自己的识别模型所需要做的训练数据准备工作。

举个例子,比如我们要准备识别鸡蛋,我们必须要准备好大量鸡蛋的图片,比如1000张,其中train训练数据900张(含验证数据),test测试数据100张,然后要对这些图像进行人工标注,我们使用 LabelImg 这款小软件,对train和test里的图片进行人工标注,如下图所示。

当数据标注好后,我们需要将图片和标注文件(xml格式)按VOC的文件格式来安放。目录结构如下图所示:

其中Annotations中存放有标注文件:

JPEGImages存放图片文件:

要注意的是标注文件和图片文件的文件名要一一对应。

最后在ImageSet的Main子目录中建立四个文本文件如下:

至此,我们数据处理的第一步就算准备好了,下一步需要通过代码生成 ImageSet/Main中的文本文件的数据,其实这步就是将所有的图片分割为训练集,验证集,测试集三个部分,并将分割完的信息写入到Main中的对应的文本文件中去,分割的代码如下:

import xml.etree.ElementTree as ET
import os
import random

import cv2 as cv


def change_image_format(old_format='.png', new_format='.jpg'):
    img_dir = "../VOC2012/JPEGImages/"

    files = os.listdir(img_dir)
    for img_file in files:
        if os.path.isfile(os.path.join(img_dir, img_file)):
            image_path = os.path.join(img_dir, img_file)
            # print(image_path)

            image = cv.imread(image_path)
            new_image_path = image_path.replace(old_format, new_format)
            cv.imwrite(new_image_path, image, [cv.IMWRITE_JPEG_QUALITY, 100])

            print("processed image : %s" % (new_image_path))


def xml_modification():
    ann_dir = "../VOC2012/Annotations/"
    img_dir = "E:/Project_Python/TensorflowTest/models/research/object_detection/ssd_model/VOCdevkit/VOC2012/JPEGImages/" #改为自己的目录
    files = os.listdir(ann_dir)
    for xml_file in files:
        if os.path.isfile(os.path.join(ann_dir, xml_file)):
            xml_path = os.path.join(ann_dir, xml_file)
            # print(xml_path)

            tree = ET.parse(xml_path)
            root = tree.getroot()

            for elem in root.iter('folder'):
                elem.text = 'voc2012'

            for elem in root.iter('filename'):
                pass

            for elem in root.iter('path'):
                path = elem.text
                filename = path.split('/')[-1]
                new_path = img_dir + filename
                elem.text = new_path

            tree.write(xml_path)

            print("processed xml : %s" % (xml_path))



def generate_train_val_test_txt():
    xml_file_path = "../VOC2012/Annotations/"
    save_Path = "../VOC2012/ImageSets/Main/"

    trainval_percent = 0.9
    train_percent = 0.9

    total_xml = os.listdir(xml_file_path)
    num = len(total_xml)
    list = range(num)

    tv = int(num * trainval_percent)
    tr = int(tv * train_percent)
    trainval = random.sample(list, tv)
    train = random.sample(trainval, tr)

    print("train and val size", tv)
    print("train size", tr)

    ftrainval = open(os.path.join(save_Path, 'trainval.txt'), 'w')
    ftest = open(os.path.join(save_Path, 'test.txt'), 'w')
    ftrain = open(os.path.join(save_Path, 'train.txt'), 'w')
    fval = open(os.path.join(save_Path, 'val.txt'), 'w')

    for i in list:
        name = total_xml[i][:-4] + '\n'
        if i in trainval:
            ftrainval.write(name)
            if i in train:
                ftrain.write(name)
            else:
                fval.write(name)
        else:
            ftest.write(name)

    ftrainval.close()
    ftrain.close()
    fval.close()
    ftest.close()


xml_modification()
generate_train_val_test_txt()

完成了这一步后,后面就要做最后一部的转换了,对于tensorflow的图像识别,训练数据都是要求为tfrecord的格式,网上的大部分教程,由于是基于tensorflow1.0,所以格式转换都要通过csv文件中转,操作起来比较麻烦,好在tesnorflow2.0后,object detection api已经为我们提供了直接转换的工具,简单的两条命令就可以完成转换:

Python object_detection/dataset_tools/create_pascal_tf_record.py --label_map_path=object_detection/ssdv2/pascal_label_map.pbtxt --data_dir=object_detection/ssdv2/VOCdevkit --year=VOC2012 --set=train --output_path=object_detection/ssdv2/egg_train.record

Python object_detection/dataset_tools/create_pascal_tf_record.py --label_map_path=object_detection/ssdv2/pascal_label_map.pbtxt --data_dir=object_detection/ssdv2/VOCdevkit --year=VOC2012 --set=val --output_path=object_detection/ssdv2/egg_val.record

这里的pascal_label_map.pbtxt文件需要提前准备好,里面是标注结果的记录。

当您看见两个record文件生成后,恭喜,训练数据就算完全准备好了,下一步走起!

  • 2
    点赞
  • 35
    收藏
    觉得还不错? 一键收藏
  • 5
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值