亲测:使用tensorflow的API训练ssd模型

从零开始到最后成功的操作过程:
(1)首先下载models,本地路径结构为~/tensorflow/models,其中~表示本地的home路径,然后在models中的research路径下执行下列两条命令
protoc object_detection/protos/*.proto --python_out=.
export PYTHONPATH="${PYTHONPATH}:~/tensorflow/models:~/tensorflow/models/research/slim/"
(这条命令在每次新起一个终端的时候都需要运行一下,或者也可以把这条命令写到系统中 :
sudo gedit ~/.bashrc 打开.bashrc将这条命令添加到最后,再执行如下命令
source ~/.bashrc)

(2)之后也是在research下,执行如下两句命令:
python setup.py build
python setup.py install

(3)过程中出现找不到模块pycocotools的,按照如下操作
git clone https://github.com/pdollar/coco
cd ./coco/PythonAPI
make -j8
将pycocotools文件夹复制到research路径下

(4)建立自己的训练路径,其中包括以VOC数据格式为标准的数据文件,比如在home路径下建立如下路径结构
~/ssd_train
~/ssd_train/train_data
~/ssd_train/train_data/Annotations (使用labelImg生成的xml标注文档)
~/ssd_train/train_data/JPEGImages (以jpg为格式的图像,这个格式要求是对应到代码中的,修改了代码中的格式要求的话这里也可以不要求jpg格式)
~/ssd_train/train_data/ImageSets
~/ssd_train/train_data/ImageSets/Main (该路径下存储的是训练数据的划分txt,这个txt可以使用脚本来生成,参加后续相关内容)
~/ssd_train/raining (该路径是最后生成的模型存储的路径,后面训练命令设置的)
~/ssd_train/test (该路径用于测试模型使用)

(5)配置标签文件和训练配置文件
将research下object_detection/data/pascal_label_map.pbtxt 拷贝到刚才建立的~/ssd_train下,并且对其按照自己的需求进行修改,这个是标签说明文档,该文档的命名也是可以修改的,在下个步骤配置文件修改中写上对应的标签文档路径和名字就而已了。
将research下object_detection/samples/configs/ssd_mobilenet_v1_pets.config 拷贝到刚才建立的~/ssd_train路径下,该配置文件是需要自己在里面做一些修改的,比如数据路径、标签文档要求等就可以了,同时在该配置文档中可以设置batchsize等参数。

(6)关于数据准备:
从xml开始,如何去制作训练可以使用的数据?
解答:
第一步,需要将xml文件和图像文件进行分块,分成哪些是训练用的,哪些是验证用的,可以使用脚本来实现,参见如下代码:
其中的mainsets_p = ~/ssd_train/train_data/ImageSets/Main
annotations_p = ~/ssd_train/train_data/Annotations
执行一下代码后train和val的txt就生成了,
注意如果需要分测试数据的话需要自己操作吧,或者在执行这个代码之前先把测试用数据拿出来。

import params_file
import os
import random
val_partion = 0.1  # 这里是设置验证图像所占的百分比


def get_train_val_image_list():
    mainsets_p = params_file.mainsets_p
    train_file = open(os.path.join(mainsets_p, 'train.txt'), 'w+')
    val_file = open(os.path.join(mainsets_p, 'val.txt'), 'w+')
    val_num = 0
    for root, dir, files in os.walk(params_file.annotations_p):
        if files is not None:
            image_num = len(files)
            val_num = int(image_num * val_partion)
            print('image_num = ', image_num)
            print('val_num = ', val_num)
            print('train_num = ', image_num - val_num)
            break
        else:
            print('the dataset is empty!')

    val_index = random.sample(range(0, image_num-1), val_num)
    for idx, file_name in enumerate(files):
        image_name = file_name.replace('.xml', '')
        if idx in val_index:
            val_file.write(image_name + '\n')
        else:
            train_file.write(image_name + '\n')

    val_file.close()
    train_file.close()


if __name__ == "__main__":
    get_train_val_image_list()

第二步,根据上面的分配,使用相关图像和xml分别生成train.record和val.record, 这里需要自己根据自己的需要对生成record的代码进行一些修改,比如路径对应啊之类的,
给的例子代码里面是以VOC数据集来的,还有用哪年哪年的选择,这里如果是自己的数据集的话就可以不需要这些东西了,直接删掉相应的判断和路径生成就行。
可以将这些过程写sh脚本文件运行,也可以运行代码,或者在python中以命令的方式运行。
追加修改后的代码:

def dict_to_tf_example(data,
                       dataset_directory,
                       label_map_dict,
                       ignore_difficult_instances=False,
                       image_subdirectory='JPEGImages'):
    """Convert XML derived dict to tf.Example proto.

    Notice that this function normalizes the bounding box coordinates provided
    by the raw data.

    Args:
      data: dict holding PASCAL XML fields for a single image (obtained by
        running dataset_util.recursive_parse_xml_to_dict)
      dataset_directory: Path to root directory holding PASCAL dataset
      label_map_dict: A map from string label names to integers ids.
      ignore_difficult_instances: Whether to skip difficult instances in the
        dataset  (default: False).
      image_subdirectory: String specifying subdirectory within the
        PASCAL dataset directory holding the actual image data.

    Returns:
      example: The converted tf.Example.

    Raises:
      ValueError: if the image pointed to by data['filename'] is not a valid JPEG
    """
    filename = data['filename'].replace('.png', '.jpg')
    filename = filename.replace(' ', '')
    filename = filename.replace('.bmp', '.jpg')
    img_path = os.path.join(image_subdirectory, filename)
    full_path = os.path.join(dataset_directory, img_path)
    with tf.gfile.GFile(full_path, 'rb') as fid:
        encoded_jpg = fid.read()
    encoded_jpg_io = io.BytesIO(encoded_jpg)
    image = PIL.Image.open(encoded_jpg_io)
    if image.format != 'JPEG':
        raise ValueError('Image format not JPEG')
    key = hashlib.sha256(encoded_jpg).hexdigest()

    width = int(data['size']['width'])
    height = int(data['size']['height'])

    xmin = []
    ymin = []
    xmax = []
    ymax = []
    classes = []
    classes_text = []
    truncated = []
    poses = []
    difficult_obj = []
    if 'object' in data:
        for obj in data['object']:
            difficult = bool(int(obj['difficult']))
            if ignore_difficult_instances and difficult:
                continue
            if not obj['name'] in label_map_dict:
                continue
            difficult_obj.append(int(difficult))

            xmin.append(float(obj['bndbox']['xmin']) / width)
            ymin.append(float(obj['bndbox']['ymin']) / height)
            xmax.append(float(obj['bndbox']['xmax']) / width)
            ymax.append(float(obj['bndbox']['ymax']) / height)
            classes_text.append(obj['name'].encode('utf8'))
            classes.append(label_map_dict[obj['name']])
            truncated.append(int(obj['truncated']))
            poses.append(obj['pose'].encode('utf8'))

    example = tf.train.Example(features=tf.train.Features(feature={
        'image/height': dataset_util.int64_feature(height),
        'image/width': dataset_util.int64_feature(width),
        'image/filename': dataset_util.bytes_feature(
            data['filename'].encode('utf8')),
        'image/source_id': dataset_util.bytes_feature(
            data['filename'].encode('utf8')),
        'image/key/sha256': dataset_util.bytes_feature(key.encode('utf8')),
        'image/encoded': dataset_util.bytes_feature(encoded_jpg),
        'image/format': dataset_util.bytes_feature('jpeg'.encode('utf8')),
        'image/object/bbox/xmin': dataset_util.float_list_feature(xmin),
        'image/object/bbox/xmax': dataset_util.float_list_feature(xmax),
        'image/object/bbox/ymin': dataset_util.float_list_feature(ymin),
        'image/object/bbox/ymax': dataset_util.float_list_feature(ymax),
        'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
        'image/object/class/label': dataset_util.int64_list_feature(classes),
        'image/object/difficult': dataset_util.int64_list_feature(difficult_obj),
        'image/object/truncated': dataset_util.int64_list_feature(truncated),
        'image/object/view': dataset_util.bytes_list_feature(poses),
    }))
    return example


def main(_):
    if FLAGS.set not in SETS:
        raise ValueError('set must be in : {}'.format(SETS))

    data_dir = FLAGS.data_dir

    writer = tf.python_io.TFRecordWriter(FLAGS.output_path)

    label_map_dict = label_map_util.get_label_map_dict(FLAGS.label_map_path)

    examples_path = os.path.join(data_dir, 'ImageSets', 'Main', FLAGS.set + '.txt')
    annotations_dir = os.path.join(data_dir, FLAGS.annotations_dir)
    examples_list = dataset_util.read_examples_list(examples_path)
    for idx, example in enumerate(examples_list):
        if idx % 100 == 0:
            logging.info('On image %d of %d', idx, len(examples_list))
        path = os.path.join(annotations_dir, example + '.xml')
        with tf.gfile.GFile(path, 'r') as fid:
            xml_str = fid.read()
        xml = etree.fromstring(xml_str)
        data = dataset_util.recursive_parse_xml_to_dict(xml)['annotation']

        tf_example = dict_to_tf_example(data, FLAGS.data_dir, label_map_dict,
                                        FLAGS.ignore_difficult_instances)
        writer.write(tf_example.SerializeToString())

    writer.close()


if __name__ == '__main__':
    tf.app.run()

(7)训练执行命令
python object_detection/model_main.py --pipeline_config_path=~/ssd_train/ssd_mobilenet_v1_pascal.config --model_dir=~/ssd_train/training --num_train_steps=10000 --num_eval_steps=20

(增加)可以在TensorBoard中查看训练的进程
:~$ source activate root
:~$ tensorboard --logdir=~/ssd_train/ # 这里的logdir应该是训练的路径,里面包含了训练使用的数据和中间的训练结果相关信息

(8)生成可以用的模型,在research路径下执行如下命令
python object_detection/export_inference_graph.py --input_type image_tensor --pipeline_config_path ~/ssd_train/ssd_mobilenet_v1_pascal.config --trained_checkpoint_prefix ~/ssd_train/training/model.ckpt-100000 --output_directory ~/ssd_train/test/
(9)测试代码
这里,注意需要将标签文档的扩展名改为txt

import cv2
import numpy as np
import tensorflow as tf
from object_detection.utils import label_map_util
from object_detection.utils import visualization_utils as vis_util


class TOD(object):
    def __init__(self):
        self.PATH_TO_CKPT = './frozen_inference_graph.pb'
        self.PATH_TO_LABELS = './pascal_label_map.txt'
        self.NUM_CLASSES = 20
        self.detection_graph = self._load_model()
        self.category_index = self._load_label_map()

    def _load_model(self):
        detection_graph = tf.Graph()
        with detection_graph.as_default():
            od_graph_def = tf.GraphDef()
            with tf.gfile.GFile(self.PATH_TO_CKPT, 'rb') as fid:
                serialized_graph = fid.read()
                od_graph_def.ParseFromString(serialized_graph)
                tf.import_graph_def(od_graph_def, name='')
        return detection_graph

    def _load_label_map(self):
        label_map = label_map_util.load_labelmap(self.PATH_TO_LABELS)
        categories = label_map_util.convert_label_map_to_categories(label_map,
                                                                    max_num_classes=self.NUM_CLASSES,
                                                                    use_display_name=True)
        category_index = label_map_util.create_category_index(categories)
        return category_index

    def detect(self, image):
        with self.detection_graph.as_default():
            with tf.Session(graph=self.detection_graph) as sess:
                # Expand dimensions since the model expects images to have shape: [1, None, None, 3]
                image_np_expanded = np.expand_dims(image, axis=0)
                image_tensor = self.detection_graph.get_tensor_by_name('image_tensor:0')
                boxes = self.detection_graph.get_tensor_by_name('detection_boxes:0')
                scores = self.detection_graph.get_tensor_by_name('detection_scores:0')
                classes = self.detection_graph.get_tensor_by_name('detection_classes:0')
                num_detections = self.detection_graph.get_tensor_by_name('num_detections:0')
                # Actual detection.
                (boxes, scores, classes, num_detections) = sess.run(
                    [boxes, scores, classes, num_detections],
                    feed_dict={image_tensor: image_np_expanded})
                # Visualization of the results of a detection.
                vis_util.visualize_boxes_and_labels_on_image_array(
                    image,
                    np.squeeze(boxes),
                    np.squeeze(classes).astype(np.int32),
                    np.squeeze(scores),
                    self.category_index,
                    use_normalized_coordinates=True,
                    line_thickness=8)

        cv2.namedWindow("detection", cv2.WINDOW_NORMAL)
        cv2.imshow("detection", image)
        cv2.waitKey(0)


if __name__ == '__main__':
    image = cv2.imread('./image/000015.jpg')
    detecotr = TOD()
    detecotr.detect(image)
    vis_util.save_image_array_as_png(image, './result/000015.png')

测试结果如下所示:
在这里插入图片描述

参考网址:
最后,在操作过程中参考了如下两个网址,对他们表示感谢
https://blog.csdn.net/chenmaolin88/article/details/79357263 全面

https://cloud.tencent.com/developer/article/1341546 使用API训练的配置参照

  • 0
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 35
    评论
评论 35
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值