Object Detection 3: Training an Object Detector (TF1)


A programmer abroad shared his experience of building a cute raccoon detector.

Original article: https://towardsdatascience.com/how-to-train-your-own-object-detector-with-tensorflows-object-detector-api-bec72ecfe1d9

The author's GitHub: https://github.com/datitran/raccoon_dataset   (the dataset can be downloaded there)

We will first walk through the whole pipeline using the original author's steps and dataset; later we can train a detector on our own data.

First, prepare the dataset

The dataset has two parts: the source images and their annotation (label) files.

Here we use the original author's data directly: https://github.com/datitran/raccoon_dataset

(base) jiadongfeng:~/tensorflow/dataset/raccoon$ git clone https://github.com/datitran/raccoon_dataset.git
Cloning into 'raccoon_dataset'...
remote: Enumerating objects: 652, done.
remote: Total 652 (delta 0), reused 0 (delta 0), pack-reused 652
Receiving objects: 100% (652/652), 48.01 MiB | 245.00 KiB/s, done.
Resolving deltas: 100% (415/415), done.

The downloaded directory looks like this:

Dataset contents:
  1. images: the source image directory, containing 200 raccoon photos
    The first image is:

    raccoon-1.jpg
  2. annotations: the annotation directory, containing the annotations for all 200 images
(base) jiadongfeng:~/tensorflow/dataset/raccoon_dataset/annotations$ ls
...
raccoon-100.xml  raccoon-146.xml  raccoon-191.xml  raccoon-55.xml
raccoon-101.xml  raccoon-147.xml  raccoon-192.xml  raccoon-56.xml
...
raccoon-144.xml  raccoon-18.xml   raccoon-53.xml   raccoon-99.xml
raccoon-145.xml  raccoon-190.xml  raccoon-54.xml   raccoon-9.xml
...

The annotation for the first image is:

<annotation verified="yes">
    <folder>images</folder><!-- folder containing the image -->
    <filename>raccoon-1.jpg</filename><!-- image file name -->
    <!-- image path; change it to your own path, e.g. /tensorflow/dataset/raccoon_dataset/images/raccoon-1.jpg -->
    <path>/Users/datitran/Desktop/raccoon/images/raccoon-1.jpg</path>
    <source>
        <database>Unknown</database>
    </source>
      <!-- image size -->
    <size>
        <width>650</width>
        <height>417</height>
        <depth>3</depth>
    </size>
    <segmented>0</segmented>
    <!-- the raccoon object(s) in the image -->
    <object>
        <name>raccoon</name>
        <pose>Unspecified</pose>
        <truncated>0</truncated>
        <difficult>0</difficult>
        <bndbox>
            <xmin>81</xmin>
            <ymin>88</ymin>
            <xmax>522</xmax>
            <ymax>408</ymax>
        </bndbox>
    </object>
</annotation>
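To make the mapping from these XML fields to code concrete, here is a minimal standard-library sketch that extracts the image size and bounding boxes from such an annotation (the conversion script below uses the API's own dataset_util.recursive_parse_xml_to_dict instead; this is just for illustration):

```python
import xml.etree.ElementTree as ET

def read_annotation(xml_text):
    # Extract image size and all object bounding boxes from a PASCAL VOC annotation
    root = ET.fromstring(xml_text)
    size = root.find('size')
    width = int(size.find('width').text)
    height = int(size.find('height').text)
    boxes = []
    for obj in root.findall('object'):
        b = obj.find('bndbox')
        boxes.append((obj.find('name').text,
                      int(b.find('xmin').text), int(b.find('ymin').text),
                      int(b.find('xmax').text), int(b.find('ymax').text)))
    return width, height, boxes

# Trimmed-down version of the raccoon-1.jpg annotation above
xml_text = """<annotation>
    <size><width>650</width><height>417</height><depth>3</depth></size>
    <object>
        <name>raccoon</name>
        <bndbox><xmin>81</xmin><ymin>88</ymin><xmax>522</xmax><ymax>408</ymax></bndbox>
    </object>
</annotation>"""
print(read_annotation(xml_text))  # (650, 417, [('raccoon', 81, 88, 522, 408)])
```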

Generating TFRecord data

Next, we convert the images and annotation files into TFRecord, a format TensorFlow supports well, to make downstream processing easier.

This assumes you have already set up the environment as described in Object Detection 1: Installing and Verifying the Object Detection API.

  1. Create the Python file
    In the ~/tensorflow/models/research/object_detection/dataset_tools directory you will find create_pascal_tf_record.py, the script TensorFlow provides to convert PASCAL VOC data to TFRecord. Copy it with the command below; we will adapt the copy to process the raccoon dataset:
cp tensorflow/models/research/object_detection/dataset_tools/create_pascal_tf_record.py /home/jdf/tensorflow/raccoon_dataset/create_raccoon_tf_record.py

The modified Python file is:

/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/create_raccoon_tf_record.py

# -*- coding: utf-8 -*-

#create_raccoon_tf_record.py

r"""Convert raw PASCAL dataset to TFRecord for object_detection.

Example usage:
    python create_raccoon_tf_record.py \
    --data_dir=/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/images \
    --set=/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_data/test_db.txt \
    --annotations_dir=/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/annotations \
    --year='VOC2007' \
    --output_path=/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_data/test_db.record \
    --label_map_path=/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_data/raccoon_label_map.pbtxt \
    --ignore_difficult_instances=False
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import hashlib
import io
import logging
import os
import sys

from lxml import etree
import PIL.Image
import tensorflow as tf

from object_detection.utils import dataset_util
from object_detection.utils import label_map_util
import warnings
warnings.filterwarnings("ignore")


# Default values
data_dir_default = '/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/images'
set_default = '/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_data/train_db.txt'
annotations_dir = '/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/annotations'
year = 'VOC2007'
output_path = '/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_data/train_db.record'
label_map_path = '/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_data/raccoon_label_map.pbtxt'
ignore_difficult_instances = False

# tf.app.flags.FLAGS receives command-line arguments passed in from the terminal
flags = tf.app.flags

# Each string flag takes three arguments: name, default value, usage description
flags.DEFINE_string('data_dir', data_dir_default, 'Root directory to raw PASCAL VOC dataset.')
flags.DEFINE_string('set', set_default, 'Convert training set, validation set or '
                    'merged set.')
flags.DEFINE_string('annotations_dir', annotations_dir,
                    '(Relative) path to annotations directory.')
flags.DEFINE_string('year', year, 'Desired challenge year.')
flags.DEFINE_string('output_path', output_path, 'Path to output TFRecord')
flags.DEFINE_string('label_map_path', label_map_path,
                    'Path to label map proto')
flags.DEFINE_boolean('ignore_difficult_instances', ignore_difficult_instances, 'Whether to ignore '
                     'difficult instances')
FLAGS = flags.FLAGS

SETS = ['train', 'val', 'trainval', 'test']
YEARS = ['VOC2007', 'VOC2012', 'merged']


def dict_to_tf_example(data,
                       dataset_directory,
                       label_map_dict,
                       ignore_difficult_instances=False,
                       image_subdirectory='JPEGImages'):
  """Convert XML derived dict to tf.Example proto.

  Notice that this function normalizes the bounding box coordinates provided
  by the raw data.

  Args:
    data: dict holding PASCAL XML fields for a single image (obtained by
      running dataset_util.recursive_parse_xml_to_dict)
    dataset_directory: Path to root directory holding PASCAL dataset
    label_map_dict: A map from string label names to integers ids.
    ignore_difficult_instances: Whether to skip difficult instances in the
      dataset  (default: False).
    image_subdirectory: String specifying subdirectory within the
      PASCAL dataset directory holding the actual image data.

  Returns:
    example: The converted tf.Example.

  Raises:
    ValueError: if the image pointed to by data['filename'] is not a valid JPEG
  """

  file_name = data['filename']
  full_path = os.path.join(dataset_directory, file_name)
  #print("#start dict_to_tf_example:"+file_name)
  with tf.gfile.GFile(full_path, 'rb') as fid:
    encoded_jpg = fid.read()
  encoded_jpg_io = io.BytesIO(encoded_jpg)
  image = PIL.Image.open(encoded_jpg_io)
  if image.format != 'JPEG':
    raise ValueError('Image format not JPEG')
  key = hashlib.sha256(encoded_jpg).hexdigest()

  width = int(data['size']['width'])
  height = int(data['size']['height'])


  xmin = []
  ymin = []
  xmax = []
  ymax = []
  classes = []
  classes_text = []
  truncated = []
  poses = []
  difficult_obj = []
  if 'object' in data:
    for obj in data['object']:
      difficult = bool(int(obj['difficult']))
      if ignore_difficult_instances and difficult:
        continue

      difficult_obj.append(int(difficult))

      xmin.append(float(obj['bndbox']['xmin']) / width)
      ymin.append(float(obj['bndbox']['ymin']) / height)
      xmax.append(float(obj['bndbox']['xmax']) / width)
      ymax.append(float(obj['bndbox']['ymax']) / height)
      classes_text.append(obj['name'].encode('utf8'))
      classes.append(label_map_dict[obj['name']])
      truncated.append(int(obj['truncated']))
      poses.append(obj['pose'].encode('utf8'))

  example = tf.train.Example(features=tf.train.Features(feature={
      'image/height': dataset_util.int64_feature(height),
      'image/width': dataset_util.int64_feature(width),
      'image/filename': dataset_util.bytes_feature(
          data['filename'].encode('utf8')),
      'image/source_id': dataset_util.bytes_feature(
          data['filename'].encode('utf8')),
      'image/key/sha256': dataset_util.bytes_feature(key.encode('utf8')),
      'image/encoded': dataset_util.bytes_feature(encoded_jpg),
      'image/format': dataset_util.bytes_feature('jpeg'.encode('utf8')),
      'image/object/bbox/xmin': dataset_util.float_list_feature(xmin),
      'image/object/bbox/xmax': dataset_util.float_list_feature(xmax),
      'image/object/bbox/ymin': dataset_util.float_list_feature(ymin),
      'image/object/bbox/ymax': dataset_util.float_list_feature(ymax),
      'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
      'image/object/class/label': dataset_util.int64_list_feature(classes),
      'image/object/difficult': dataset_util.int64_list_feature(difficult_obj),
      'image/object/truncated': dataset_util.int64_list_feature(truncated),
      'image/object/view': dataset_util.bytes_list_feature(poses),
  }))
  #print("##finish dict_to_tf_example:"+file_name)
  return example


# Entry point

def main(_):
  # (validation from the original script, disabled here because FLAGS.set is a
  #  file path rather than one of SETS)
  #if FLAGS.set not in SETS:
  #  raise ValueError('set must be in : {}'.format(SETS))
  #if FLAGS.year not in YEARS:
  #  raise ValueError('year must be in : {}'.format(YEARS))
  print("#start main...")
  data_dir = FLAGS.data_dir
  print("--data_dir [%s]"%data_dir)
  years = ['VOC2007', 'VOC2012']
  if FLAGS.year != 'merged':
    years = [FLAGS.year]

  print("--years [%s]"%str(years))

  writer = tf.io.TFRecordWriter(FLAGS.output_path)
  print("--writer to [%s]"%str(FLAGS.output_path))

  print("--label_map_path[%s]"%str(FLAGS.label_map_path))

  label_map_dict = label_map_util.get_label_map_dict(FLAGS.label_map_path)
  print("--label_map_dict[%s]"%str(label_map_dict))
  print("#start for years....")
  for year in years:
    print('#Reading from PASCAL %s dataset.'%year)
    examples_path = FLAGS.set
    print("--examples_path[%s]"%examples_path)

    annotations_dir = FLAGS.annotations_dir
    print("--annotations_dir[%s]"%annotations_dir)

    examples_list = dataset_util.read_examples_list(examples_path)

    for idx, example in enumerate(examples_list):
      if idx % 10 == 0:
        print('On image %d of %d' %(idx,len(examples_list)))
      path = os.path.join(annotations_dir, example + '.xml')
      with tf.gfile.GFile(path, 'r') as fid:
        xml_str = fid.read()
      xml = etree.fromstring(xml_str)
      #parse every xml file's annotation in 'raccoon_dataset/annotations'
      data = dataset_util.recursive_parse_xml_to_dict(xml)['annotation']

      #logging.info('dict_to_tf for %s', data['filename'])
      tf_example = dict_to_tf_example(data, FLAGS.data_dir, label_map_dict,
                                      FLAGS.ignore_difficult_instances)
      writer.write(tf_example.SerializeToString())

  writer.close()
  print("#close writer %s"%FLAGS.output_path)


if __name__ == '__main__':
  # Run the main function.
  # If the entry point were not main() but e.g. test(), call tf.app.run(test) instead.
  tf.app.run()
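The core transformation in dict_to_tf_example above is bounding-box normalization: pixel coordinates are divided by the image width and height to get [0, 1] fractions. A small worked example using the raccoon-1.jpg annotation shown earlier (650×417 image, box 81, 88, 522, 408):

```python
def normalize_bbox(xmin, ymin, xmax, ymax, width, height):
    # Same arithmetic as dict_to_tf_example: pixel coordinates -> [0, 1] fractions
    return xmin / width, ymin / height, xmax / width, ymax / height

# Values from the raccoon-1.jpg annotation shown earlier
box = normalize_bbox(81, 88, 522, 408, width=650, height=417)
print([round(v, 4) for v in box])  # [0.1246, 0.211, 0.8031, 0.9784]
```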

Among the files passed to the script above, the main ones are:

  • raccoon_label_map.pbtxt
    location: /home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_data/raccoon_label_map.pbtxt

The framework needs a mapping between our class IDs and class names, usually kept in a .pbtxt file with contents like:

item {
  id: 1
  name: 'raccoon'
}

Since we have only one class, a single item is enough; with multiple classes you need one item per class.
Note that ids start at 1, and name must exactly match the class name used in the annotation files: if you labeled the images as raccoon, write raccoon here, not "浣熊".
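The conversion script loads this file through label_map_util.get_label_map_dict, which yields a name-to-id dict. Purely for illustration of what that mapping looks like, here is a toy standard-library parser (parse_simple_label_map is a hypothetical helper, not part of the API) that handles flat files like the one above:

```python
import re

def parse_simple_label_map(pbtxt_text):
    """Parse flat item { id: N  name: '...' } entries into a {name: id} dict.

    A toy stand-in for label_map_util.get_label_map_dict; it only handles
    the simple single-level format shown above.
    """
    mapping = {}
    for item in re.findall(r"item\s*\{(.*?)\}", pbtxt_text, re.S):
        id_m = re.search(r"id:\s*(\d+)", item)
        name_m = re.search(r"name:\s*'([^']*)'", item)
        if id_m and name_m:
            mapping[name_m.group(1)] = int(id_m.group(1))
    return mapping

pbtxt = """
item {
  id: 1
  name: 'raccoon'
}
"""
print(parse_simple_label_map(pbtxt))  # {'raccoon': 1}
```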

  • train_db.txt
    location: /home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_data/train_db.txt
    The training set: 160 of the 200 samples
raccoon-5
raccoon-12
raccoon-107
raccoon-116
raccoon-123
raccoon-70
raccoon-152
raccoon-63
raccoon-135
raccoon-161
raccoon-171
raccoon-118
raccoon-124
raccoon-169
raccoon-38
raccoon-98
raccoon-158
raccoon-93
raccoon-34
raccoon-69
raccoon-35
raccoon-146
raccoon-78
raccoon-19
raccoon-127
raccoon-66
raccoon-117
raccoon-62
raccoon-200
raccoon-122
raccoon-173
raccoon-33
raccoon-73
raccoon-77
raccoon-7
raccoon-191
raccoon-86
raccoon-180
raccoon-61
raccoon-60
raccoon-49
raccoon-32
raccoon-27
raccoon-197
raccoon-126
raccoon-189
raccoon-75
raccoon-156
raccoon-192
raccoon-57
raccoon-167
raccoon-45
raccoon-65
raccoon-82
raccoon-184
raccoon-3
raccoon-178
raccoon-30
raccoon-164
raccoon-67
raccoon-44
raccoon-166
raccoon-43
raccoon-168
raccoon-170
raccoon-132
raccoon-108
raccoon-101
raccoon-20
raccoon-2
raccoon-22
raccoon-11
raccoon-74
raccoon-176
raccoon-114
raccoon-14
raccoon-36
raccoon-129
raccoon-177
raccoon-141
raccoon-151
raccoon-94
raccoon-179
raccoon-130
raccoon-128
raccoon-193
raccoon-104
raccoon-8
raccoon-137
raccoon-76
raccoon-185
raccoon-26
raccoon-81
raccoon-190
raccoon-120
raccoon-175
raccoon-112
raccoon-90
raccoon-46
raccoon-91
raccoon-13
raccoon-119
raccoon-149
raccoon-50
raccoon-181
raccoon-162
raccoon-136
raccoon-53
raccoon-143
raccoon-48
raccoon-163
raccoon-125
raccoon-31
raccoon-188
raccoon-37
raccoon-154
raccoon-157
raccoon-195
raccoon-47
raccoon-97
raccoon-187
raccoon-80
raccoon-153
raccoon-139
raccoon-147
raccoon-25
raccoon-84
raccoon-174
raccoon-110
raccoon-59
raccoon-52
raccoon-99
raccoon-4
raccoon-92
raccoon-186
raccoon-1
raccoon-41
raccoon-71
raccoon-194
raccoon-10
raccoon-134
raccoon-140
raccoon-16
raccoon-142
raccoon-172
raccoon-24
raccoon-109
raccoon-89
raccoon-160
raccoon-111
raccoon-54
raccoon-15
raccoon-182
raccoon-18
raccoon-144
raccoon-138
raccoon-39
raccoon-6
raccoon-51
raccoon-103

The Python script that generates this list:

import os

# NOTE: os.listdir order is arbitrary, so this split is not reproducible across runs
i = 0
pt = "/home/jdf/tensorflow/raccoon_dataset/images"
image_name = os.listdir(pt)
for temp in image_name:
    if temp.endswith(".jpg"):
        if i < 160:
            print(temp.replace('.jpg', ''))
        i = i + 1

Run it directly to get the 160 training entries.

  • test_db.txt
    location: /home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_data/test_db.txt
    The test set: the other 40 of the 200 samples
raccoon-23
raccoon-21
raccoon-85
raccoon-131
raccoon-29
raccoon-115
raccoon-183
raccoon-199
raccoon-72
raccoon-17
raccoon-83
raccoon-9
raccoon-56
raccoon-68
raccoon-87
raccoon-100
raccoon-79
raccoon-145
raccoon-64
raccoon-96
raccoon-196
raccoon-58
raccoon-105
raccoon-106
raccoon-148
raccoon-42
raccoon-55
raccoon-40
raccoon-155
raccoon-88
raccoon-165
raccoon-28
raccoon-102
raccoon-133
raccoon-113
raccoon-95
raccoon-121
raccoon-150
raccoon-159
raccoon-198

The Python script that generates this list:

import os

# NOTE: os.listdir order is arbitrary, so this split is not reproducible across runs
i = 0
pt = "/home/jdf/tensorflow/raccoon_dataset/images"
image_name = os.listdir(pt)
for temp in image_name:
    if temp.endswith(".jpg"):
        if i >= 160:
            print(temp.replace('.jpg', ''))
        i = i + 1

And the script that simply lists every image name (the full set being split):

import os

pt = "/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/images"
image_name = os.listdir(pt)
for temp in image_name:
    if temp.endswith(".jpg"):
        print(temp.replace('.jpg', ''))
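The listing scripts above depend on os.listdir order, which is not guaranteed to be stable across runs or machines. If reproducibility matters, a seeded shuffle gives a deterministic 160/40 split. This is a sketch with stand-in names, not the script that produced the lists above:

```python
import random

def split_dataset(names, train_count, seed=42):
    # Sort first, then shuffle a copy with a fixed seed so the split is reproducible
    names = sorted(names)
    rng = random.Random(seed)
    rng.shuffle(names)
    return names[:train_count], names[train_count:]

# Example with stand-in names; on the real data, pass the 200 image
# basenames from the images directory and train_count=160
names = [f"raccoon-{i}" for i in range(1, 11)]
train, test = split_dataset(names, train_count=8)
print(len(train), len(test))  # 8 2
```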
  2. Generate the training-set TFRecord file

Before running the commands below, make sure you have compiled the protos and set PYTHONPATH:

cd ~/tensorflow/models/research/
protoc object_detection/protos/*.proto --python_out=.
export PYTHONPATH=$PYTHONPATH:`pwd`:`pwd`/slim
python object_detection/builders/model_builder_test.py

Otherwise you will hit this error:

ImportError: No module named object_detection.utils

Run the command to generate the TFRecord file:

python create_raccoon_tf_record.py \
    --data_dir=/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/images \
    --set=/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_data/train_db.txt \
    --annotations_dir=/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/annotations \
    --year='VOC2007' \
    --output_path=/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_data/train_db.record \
    --label_map_path=/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_data/raccoon_label_map.pbtxt \
    --ignore_difficult_instances=False

data_dir: the directory of collected source images
set: the 160-entry training list
annotations_dir: the per-image annotation files
year: dataset year; the code validates it and you can adjust the code as needed — it has no special meaning here
label_map_path: as described above, the class-label map for the data to be trained

The output is:

/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_data/train_db.record
  3. Generate the test-set TFRecord file
python create_raccoon_tf_record.py \
    --data_dir=/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/images \
    --set=/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_data/test_db.txt \
    --annotations_dir=/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/annotations \
    --year='VOC2007' \
    --output_path=/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_data/test_db.record \
    --label_map_path=/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_data/raccoon_label_map.pbtxt \
    --ignore_difficult_instances=False

The parameters mean the same as for the training set.

The output is:

/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_data/test_db.record
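To sanity-check the generated .record files without loading TensorFlow, you can walk the TFRecord framing directly: each record is a little-endian uint64 payload length, a 4-byte masked CRC of that length, the payload, and a 4-byte masked CRC of the payload. A minimal counter that skips (does not verify) the CRCs:

```python
import struct

def count_tfrecords(path):
    """Count records in a TFRecord file by walking its length-prefixed framing.

    Each record is: uint64 length, uint32 masked CRC of the length,
    `length` payload bytes, uint32 masked CRC of the payload.
    CRCs are skipped, not verified.
    """
    count = 0
    with open(path, "rb") as f:
        while True:
            header = f.read(8)
            if len(header) < 8:
                break
            (length,) = struct.unpack("<Q", header)
            f.seek(4 + length + 4, 1)  # skip length-CRC, payload, payload-CRC
            count += 1
    return count
```

On this dataset you would expect 160 records in train_db.record and 40 in test_db.record.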

Download the pretrained model

In the Installing and Verifying the Object Detection API chapter we already downloaded the detection API, which includes the model checkpoint files (model.ckpt); we use the stock ssd_mobilenet_v1 as the pretrained model and fine-tune on top of it via transfer learning.

(base) jiadongfeng:~/tensorflow/models/research/object_detection/ssd_mobilenet_v1_coco_2017_11_17$ ls
checkpoint                      model.ckpt.index
frozen_inference_graph.pb       model.ckpt.meta
model.ckpt.data-00000-of-00001  saved_model

Create the config file

We start from an existing config file and modify it.

Copy
 object_detection/samples/configs/ssd_mobilenet_v1_coco.config
to
/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_config/ssd_mobilenet_v1_raccoon.config
# SSD with Mobilenet v1 configuration for MSCOCO Dataset.
# Users should configure the fine_tune_checkpoint field in the train config as
# well as the label_map_path and input_path fields in the train_input_reader and
# eval_input_reader. Search for "PATH_TO_BE_CONFIGURED" to find the fields that
# should be configured.

model {
  ssd {
    num_classes: 1
    box_coder {
      faster_rcnn_box_coder {
        y_scale: 10.0
        x_scale: 10.0
        height_scale: 5.0
        width_scale: 5.0
      }
    }
    matcher {
      argmax_matcher {
        matched_threshold: 0.5
        unmatched_threshold: 0.5
        ignore_thresholds: false
        negatives_lower_than_unmatched: true
        force_match_for_each_row: true
      }
    }
    similarity_calculator {
      iou_similarity {
      }
    }
    anchor_generator {
      ssd_anchor_generator {
        num_layers: 6
        min_scale: 0.2
        max_scale: 0.95
        aspect_ratios: 1.0
        aspect_ratios: 2.0
        aspect_ratios: 0.5
        aspect_ratios: 3.0
        aspect_ratios: 0.3333
      }
    }
    image_resizer {
      fixed_shape_resizer {
        height: 300
        width: 300
      }
    }
    box_predictor {
      convolutional_box_predictor {
        min_depth: 0
        max_depth: 0
        num_layers_before_predictor: 0
        use_dropout: false
        dropout_keep_probability: 0.8
        kernel_size: 1
        box_code_size: 4
        apply_sigmoid_to_scores: false
        conv_hyperparams {
          activation: RELU_6,
          regularizer {
            l2_regularizer {
              weight: 0.00004
            }
          }
          initializer {
            truncated_normal_initializer {
              stddev: 0.03
              mean: 0.0
            }
          }
          batch_norm {
            train: true,
            scale: true,
            center: true,
            decay: 0.9997,
            epsilon: 0.001,
          }
        }
      }
    }
    feature_extractor {
      type: 'ssd_mobilenet_v1'
      min_depth: 16
      depth_multiplier: 1.0
      conv_hyperparams {
        activation: RELU_6,
        regularizer {
          l2_regularizer {
            weight: 0.00004
          }
        }
        initializer {
          truncated_normal_initializer {
            stddev: 0.03
            mean: 0.0
          }
        }
        batch_norm {
          train: true,
          scale: true,
          center: true,
          decay: 0.9997,
          epsilon: 0.001,
        }
      }
    }
    loss {
      classification_loss {
        weighted_sigmoid {
        }
      }
      localization_loss {
        weighted_smooth_l1 {
        }
      }
      hard_example_miner {
        num_hard_examples: 3000
        iou_threshold: 0.99
        loss_type: CLASSIFICATION
        max_negatives_per_positive: 3
        min_negatives_per_image: 0
      }
      classification_weight: 1.0
      localization_weight: 1.0
    }
    normalize_loss_by_num_matches: true
    post_processing {
      batch_non_max_suppression {
        score_threshold: 1e-8
        iou_threshold: 0.6
        max_detections_per_class: 100
        max_total_detections: 100
      }
      score_converter: SIGMOID
    }
  }
}

train_config: {
  batch_size: 24
  optimizer {
    rms_prop_optimizer: {
      learning_rate: {
        exponential_decay_learning_rate {
          initial_learning_rate: 0.004
          decay_steps: 800720
          decay_factor: 0.95
        }
      }
      momentum_optimizer_value: 0.9
      decay: 0.9
      epsilon: 1.0
    }
  }
 # Change 1: point fine_tune_checkpoint at the transfer-learning pretrained model
  fine_tune_checkpoint: "/home/jiadongfeng/tensorflow/models/research/object_detection/ssd_mobilenet_v1_coco_2017_11_17/model.ckpt"
  from_detection_checkpoint: true
  # Note: The below line limits the training process to 200K steps, which we
  # empirically found to be sufficient enough to train the pets dataset. This
  # effectively bypasses the learning rate schedule (the learning rate will
  # never decay). Remove the below line to train indefinitely.
  num_steps: 200000
  data_augmentation_options {
    random_horizontal_flip {
    }
  }
  data_augmentation_options {
    ssd_random_crop {
    }
  }
}

# Change 2: training-set record and label-map files
train_input_reader: {
  tf_record_input_reader {
    input_path: "/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_data/train_db.record"
  }
  label_map_path: "/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_data/raccoon_label_map.pbtxt"
}

# Change 3: number of test images and evaluation rounds
eval_config: {
  num_examples: 40
  # Note: The below line limits the evaluation process to 10 evaluations.
  # Remove the below line to evaluate indefinitely.
  max_evals: 10
}
# Change 4: test-set record and label-map files
eval_input_reader: {
  tf_record_input_reader {
    input_path: "/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_data/test_db.record"
  }
  label_map_path: "/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_data/raccoon_label_map.pbtxt"
  shuffle: false
  num_readers: 1
}
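A quick sanity check on these numbers: with batch_size 24 and our 160 training images, one epoch is only ceil(160/24) = 7 steps, so the default num_steps of 200k amounts to tens of thousands of passes over this tiny dataset — far more than necessary, which is why stopping training early (as done below) is reasonable:

```python
import math

batch_size, train_images, num_steps = 24, 160, 200_000
steps_per_epoch = math.ceil(train_images / batch_size)
epochs = num_steps / steps_per_epoch
print(steps_per_epoch, round(epochs))  # 7 28571
```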

Start training

From the object_detection directory, run the following command to start training:

python ./legacy/train.py --logtostderr \
    --pipeline_config_path=/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_config/ssd_mobilenet_v1_raccoon.config \
    --train_dir=/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_train

Output:

...
I0303 12:19:59.530555 140558114309504 learning.py:507] global step 10274: loss = 0.8229 (4.883 sec/step)
...
Training takes a long time: the config sets num_steps to 200k, and at roughly 5 seconds per step the full run would need about 11.5 days.
I stopped it manually after 40-some-thousand steps:

INFO:tensorflow:global step 48267: loss = 1.3767 (8.383 sec/step)
I0308 16:24:15.022326 140503706822016 learning.py:507] global step 48267: loss = 1.3767 (8.383 sec/step)
INFO:tensorflow:global step 48268: loss = 0.7260 (11.070 sec/step)
I0308 16:24:26.094815 140503706822016 learning.py:507] global step 48268: loss = 0.7260 (11.070 sec/step)

Evaluate the trained model on the test set

python ./legacy/eval.py --logtostderr \
    --pipeline_config_path=/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_config/ssd_mobilenet_v1_raccoon.config \
    --checkpoint_dir=/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_train \
    --eval_dir=/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/eval

Output:

W0314 16:56:10.875488 139739513279872 deprecation_wrapper.py:119] From /home/jiadongfeng/tensorflow/models/research/object_detection/exporter.py:274: The name tf.saved_model.tag_constants.SERVING is deprecated. Please use tf.saved_model.SERVING instead.
INFO:tensorflow:No assets to save.
I0314 16:56:10.876045 139739513279872 builder_impl.py:636] No assets to save.
INFO:tensorflow:No assets to write.
I0314 16:56:10.876254 139739513279872 builder_impl.py:456] No assets to write.
INFO:tensorflow:SavedModel written to: /home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_train/saved_model/saved_model.pb
I0314 16:56:11.369002 139739513279872 builder_impl.py:421] SavedModel written to: /home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_train/saved_model/saved_model.pb
WARNING:tensorflow:From /home/jiadongfeng/tensorflow/models/research/object_detection/utils/config_util.py:180: The name tf.gfile.Open is deprecated. Please use tf.io.gfile.GFile instead.
INFO:tensorflow:Writing pipeline config file to /home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_train/pipeline.config
I0314 16:56:11.420886 139739513279872 config_util.py:182] Writing pipeline config file to /home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_train/pipeline.config

Export the checkpoint as a frozen model file

python export_inference_graph.py \
    --pipeline_config_path=/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_config/ssd_mobilenet_v1_raccoon.config  \
    --trained_checkpoint_prefix=/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_train/model.ckpt-48253 \
    --output_directory=/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_train

The exported result is:

/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_train/frozen_inference_graph.pb

Detect raccoons with the model

Modify the demo Python file (the complete TensorFlow object-detection API demo) from Object Detection 1: Installing and Verifying the Object Detection API:

....
# the frozen .pb file produced by the training above
PATH_TO_FROZEN_GRAPH = '/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_train/frozen_inference_graph.pb'
# the label_map file for the image data
PATH_TO_LABELS = '/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_data/raccoon_label_map.pbtxt'
...

###Detection
# For the sake of simplicity we will use only 2 images:
# image1.jpg
# image2.jpg
# If you want to test the code with your images, just add path to the images to the TEST_IMAGE_PATHS.
PATH_TO_TEST_IMAGES_DIR = 'test_images'
# path of the image to run detection on
TEST_IMAGE_PATHS = [ '/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/images/raccoon-33.jpg' ]
...

For the complete Python file, see: the raccoon detection demo.

Detection result:

[Figure: raccoon-33_ret.png — detection result on raccoon-33.jpg]

Troubleshooting:

  1. ImportError: No module named 'pycocotools'

Solution: get the COCO API source from https://github.com/waleedka/coco. In a terminal:

git clone https://github.com/waleedka/coco

Alternatively, download the ZIP from GitHub directly; if you are working on a remote server over SSH, download it locally first and scp it to the server.
Then enter the downloaded coco-master directory, go into PythonAPI, and install:

PythonAPI$ pip install pycocotools
  2. limits.h:194:15: fatal error: limits.h: No such file or directory
sudo apt-get install build-essential  # install build-essential (optional)

sudo apt-get update                   # install linux-headers
sudo apt-get install linux-headers-$(uname -r)

Or as a one-liner:
sudo apt-get update && sudo apt-get install build-essential linux-headers-$(uname -r)
  3. The train.py process gets killed during training

Go to the log directory:
cd /var/log/

Look for the killed-process entries:

journalctl -xb | egrep -i 'killed process'
Feb 28 20:28:44 jiadongfeng-VirtualBox kernel: Killed process 1612 (train.py) total-vm:6352780kB, anon-rss:3783324kB, file-rss:0kB, shmem-rss:0kB
Feb 28 20:43:14 jiadongfeng-VirtualBox kernel: Killed process 1748 (train.py) total-vm:6257152kB, anon-rss:3692180kB, file-rss:0kB, shmem-rss:0kB

Training needs a lot of memory, and the process was killed because the machine ran out of it. Raise your virtual machine's memory limit; I added an extra 8 GB of RAM to my computer for this.
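Besides adding RAM, a common mitigation is lowering batch_size in the training config, which directly reduces peak memory during training (8 here is an illustrative value, not a tuned one):

```
train_config: {
  batch_size: 8   # reduced from 24 to cut peak memory usage
  ...
}
```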

References:

https://blog.csdn.net/chenmaolin88/article/details/79357263

https://towardsdatascience.com/how-to-train-your-own-object-detector-with-tensorflows-object-detector-api-bec72ecfe1d9
