tensorflow object detection学习

姓名:Jyx
描述:人工智能学习笔记

tensorflow object detection api

tensorflow object detection api 并不在主项目内,而是在单独开发的一个项目内,其路径为https://github.com/tensorflow/models/tree/master/research/object_detection

tensorflow object detection 基本流程

  1. 数据准备。将数据打包成object detection支持的格式,一般为tfrecord
  2. 参数配置。定义pipeline config文件
  3. 模型训练。调用模型训练
  4. 可视化

从上面的过程可以看出:tensorflow api封装的极为完善。基本只要准备好数据,在pipeline里面定义好模型参数就可以使用。

###1. 数据准备
object detection 框架接收一个tfrecord格式的文件作为输入,而一般的数据集都是图片和标记单独文件给出的,所以有个格式转换的过程。
tfrecord是一种二进制文件,由数据对象序列化而成。数据对象的格式需要为tf.train.Example, 其成员为tf.train.Feature. 所以数据准备就是:

  1. 准备数据字典,并转换为Feature存储,字典成员可以参考https://github.com/tensorflow/models/blob/master/research/object_detection/core/standard_fields.py中的定义
  2. 通过tf.train.Example转换为Example格式
  3. Example序列化并通过tf.python_io.TFRecordWriter写入文件

代码简单实现如下


import tensorflow as tf
import PIL.Image as im
import numpy as np
import sys,os,io
import xmltodict
import matplotlib.pyplot as plt
import random

%matplotlib inline

dpath = '../../data'
image_dir = os.path.join(dpath, 'images')
annotation_dir = os.path.join(dpath, 'annotations', 'xmls')

tfrecord_dir = 'data'
train_file_name = os.path.join(tfrecord_dir, 'train.record')

annotation_names = os.listdir(annotation_dir)
#在例子中为了方便只是随机选取了10个文件作为input
annotation_names = random.choices(annotation_names, k = 10)
image_names = [ os.path.splitext(annotation_name)[0] + '.jpg' for annotation_name in annotation_names ]

image_paths = [ os.path.join(image_dir, image_name) for image_name in image_names ]
annotation_paths = [ os.path.join(annotation_dir, annotation_name) for annotation_name in annotation_names ]

labels_map = {
    'cat': 1,
    'dog': 2
}
def build_record(image_path, annotation_path):
    with tf.gfile.GFile(image_path, 'rb') as fimage:
        image_data = fimage.read()
    
    with im.open(io.BytesIO(image_data)) as fim:
        width, height = fim.size
    
    with open(annotation_path, 'rb') as fxml:
        xml_data = xmltodict.parse(fxml.read())
        
    bndbox = xml_data['annotation']['object']['bndbox']
    
    #准备feature字典
    feature_dict = {
        'image/encoded': tf.train.Feature(bytes_list = tf.train.BytesList(value=[image_data])),
        'image/format': tf.train.Feature(bytes_list = tf.train.BytesList(value = [ b'jpeg'])),
        'image/object/bbox/xmin': tf.train.Feature(float_list = tf.train.FloatList(value = [ float(bndbox['xmin'])/width ])),
        'image/object/bbox/ymin': tf.train.Feature(float_list = tf.train.FloatList(value = [ float(bndbox['ymin'])/height ])),
        'image/object/bbox/xmax': tf.train.Feature(float_list = tf.train.FloatList(value = [ float(bndbox['xmax'])/width ])),
        'image/object/bbox/ymax': tf.train.Feature(float_list = tf.train.FloatList(value = [ float(bndbox['ymax'])/height ])),
    }
    
    #打包成features
    features = tf.train.Features( feature = feature_dict)
    #转换成example格式
    example = tf.train.Example(features = features)
    
    return example

with tf.python_io.TFRecordWriter(train_file_name) as writer:
    for image_path, annotation_path in zip(image_paths, annotation_paths):
        example = build_record(image_path, annotation_path)
        #写入tfrecord
        writer.write(example.SerializeToString())

上面的代码演示了如何创建一个简单的tfrecord文件。在实际的应用场景中可以利用框架中预先实现的一些转换函数。这些函数定义在https://github.com/tensorflow/models/tree/master/research/object_detection/dataset_tools下,其模块名中包含该代码所对应的数据集。根据实际需要可能需要适当修改
调用过程大致如下

python object_detection/dataset_tools/create_pet_tf_record.py \
    --label_map_path=object_detection/data/pet_label_map.pbtxt \
    --data_dir=data_dir \
    --output_dir=output_dir

###2. 参数配置
object detection 框架中模型并不需要用户创建,当然也支持用户创建,但这些不在本文范围。
config文件主要由5部分组成

model {
(... Add model config here...)
}

train_config : {
(... Add train_config here...)
}

train_input_reader: {
(... Add train_input configuration here...)
}

eval_config: {
}

eval_input_reader: {
(... Add eval_input configuration here...)
}

models:
models中的配置项视所选择的模型而定,可选择的模型在model文件夹下https://github.com/tensorflow/models/tree/master/research/object_detection/models。
train_config:

  1. 模型训练参数初始化
  2. 输入预处理
  3. sgd优化参数

train_input_reader:

主要是两个参数
输入文件 input_path: “/usr/home/username/data/train.record”
标签映射文件 label_map_path: “/usr/home/username/data/label_map.pbtxt”

eval_config:

测试参数配置

eval_input_reader:

主要是两个参数
输入文件 input_path: “/usr/home/username/data/train.record”
标签映射文件 label_map_path: “/usr/home/username/data/label_map.pbtxt”

###3. 模型训练:
主要是调用model_main.py 并配置一些输入参数

PIPELINE_CONFIG_PATH=ssd_mobilenet_v1_pets.config
MODEL_DIR=./model
NUM_TRAIN_STEPS=20003
NUM_EVAL_STEPS=2000

python -m object_detection.model_main \
    --pipeline_config_path=${PIPELINE_CONFIG_PATH} \
    --model_dir=${MODEL_DIR} \
    --num_train_steps=${NUM_TRAIN_STEPS} \
    --num_eval_steps=${NUM_EVAL_STEPS} \
    --alsologtostderr

###4. 可视化
可以利用https://github.com/tensorflow/models/blob/master/research/object_detection/utils/visualization_utils.py中的visualize_boxes_and_labels_on_image_array函数。

  image = Image.open(image_path)
  # the array based representation of the image will be used later in order to prepare the
  # result image with boxes and labels on it.
  image_np = load_image_into_numpy_array(image)
  # Expand dimensions since the model expects images to have shape: [1, None, None, 3]
  image_np_expanded = np.expand_dims(image_np, axis=0)
  # Actual detection.
  output_dict = run_inference_for_single_image(image_np, detection_graph)
  # Visualization of the results of a detection.
  vis_util.visualize_boxes_and_labels_on_image_array(
      image_np,
      output_dict['detection_boxes'],
      output_dict['detection_classes'],
      output_dict['detection_scores'],
      category_index,
      instance_masks=output_dict.get('detection_masks'),
      use_normalized_coordinates=True,
      line_thickness=8)
  plt.imsave(os.path.join('./data/output', os.path.basename(image_path)), image_np)
  plt.figure(figsize=IMAGE_SIZE)
  plt.imshow(image_np)
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值