Pre-trained model download:
https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/tf1_detection_zoo.md
Installing object_detection on Windows 10: a tutorial
References
Preparing the workspace
Basic directory structure
TensorFlow/
├─ addons/ (Optional)
│ └─ labelImg/
├─ models/
│ ├─ community/
│ ├─ official/
│ ├─ orbit/
│ ├─ research/
│ └─ …
└─ workspace/
  └─ training_demo/
training_demo directory details
training_demo/
├─ annotations/ dataset annotations
├─ exported-models/ exported trained models
├─ models/ configuration files
├─ pre-trained-models/ pre-trained models
└─ README.md
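The layout above can be created by hand, or scripted. Below is a minimal sketch that builds the training_demo part of the tree; the create_workspace helper is a hypothetical name introduced here for illustration and is not part of the Object Detection API.

```python
from pathlib import Path

# Sub-folders of training_demo, as listed in the tree above.
TRAINING_DEMO_DIRS = ["annotations", "exported-models", "models", "pre-trained-models"]


def create_workspace(root):
    """Create <root>/workspace/training_demo and its sub-folders (hypothetical helper)."""
    demo = Path(root) / "workspace" / "training_demo"
    for name in TRAINING_DEMO_DIRS:
        # exist_ok=True makes the call safe to repeat on an existing workspace.
        (demo / name).mkdir(parents=True, exist_ok=True)
    (demo / "README.md").touch(exist_ok=True)
    return demo
```

Calling create_workspace("TensorFlow") from the parent folder would produce the workspace/ branch shown above; the models/ repository and the optional addons/ folder still have to be cloned separately.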
Download a pre-trained model
https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/tf2_detection_zoo.md
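Extract it into pre-trained-models (ssd_mobilenet_v1_coco_2018_01_28, used here, is a TF1 model and appears in the TF1 zoo linked at the top):
training_demo/
├─ annotations/ dataset annotations
├─ exported-models/ exported trained models
├─ models/ configuration files
├─ pre-trained-models/
│ └─ ssd_mobilenet_v1_coco_2018_01_28/
│   ├─ saved_model/
│   ├─ checkpoint
│   ├─ model.ckpt.index
│   └─ …
└─ README.md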
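Models in the zoo are distributed as .tar.gz archives, so the extraction step can be scripted. This is a minimal sketch that assumes the archive has already been downloaded; extract_model is a hypothetical helper name, and the paths in the usage note are placeholders.

```python
import tarfile
from pathlib import Path


def extract_model(archive_path, dest_dir):
    """Extract a downloaded .tar.gz model archive into dest_dir (hypothetical helper).

    Returns the sorted names of the entries found directly under dest_dir.
    Only extract archives from a source you trust.
    """
    dest = Path(dest_dir)
    dest.mkdir(parents=True, exist_ok=True)
    with tarfile.open(archive_path, "r:gz") as tar:
        tar.extractall(dest)
    return sorted(p.name for p in dest.iterdir())
```

For example, extract_model("ssd_mobilenet_v1_coco_2018_01_28.tar.gz", "training_demo/pre-trained-models") should leave the model folder next to the other pre-trained models.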
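Edit the pipeline.config file
In the models directory, create a folder named ssd_mobilenet_v1_coco (any name will do).
Put label_map.pbtxt and pipeline.config in it.
Write label_map.pbtxt following this pattern: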
item {
  id: 1
  name: 'cat'
}
item {
  id: 2
  name: 'dog'
}
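With many classes, writing the label map by hand gets error-prone. The following minimal sketch generates the same text from a list of class names; build_label_map is a hypothetical helper, and ids start at 1 because id 0 is reserved for the background class.

```python
def build_label_map(class_names):
    """Return label-map text for class_names, with ids starting at 1."""
    items = []
    for index, name in enumerate(class_names, start=1):
        items.append("item {\n  id: %d\n  name: '%s'\n}\n" % (index, name))
    return "".join(items)


# Reproduces the two-class example above.
label_map_text = build_label_map(["cat", "dog"])
```

Writing label_map_text to the label map file in the new folder gives the same content as the hand-written version.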
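For pipeline.config, find the pipeline configuration file shipped with the pre-trained model and copy it into the new folder.
Directory structure at this point
training_demo/
├─ annotations/ dataset annotations
├─ exported-models/ exported trained models
├─ models/
│ └─ ssd_mobilenet_v1_coco/
│   ├─ pipeline.config
│   └─ label_map.pbtxt
├─ pre-trained-models/ pre-trained models
└─ README.md
Pipeline configuration: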
model {
  ssd {
    num_classes: 12  # change: number of classes in your label map
    image_resizer {
      fixed_shape_resizer {
        height: 300
        width: 300
      }
    }
    feature_extractor {
      type: "ssd_mobilenet_v1"  # change
      depth_multiplier: 1.0
      min_depth: 16
      ... ...
      ... ...
    }
  }
}
train_config {
  batch_size: 128  # change
  data_augmentation_options {
    random_horizontal_flip {
    }
  }
  data_augmentation_options {
    ssd_random_crop {
    }
  }
  optimizer {
    rms_prop_optimizer {
      learning_rate {
        exponential_decay_learning_rate {
          initial_learning_rate: 0.00400000018999
          decay_steps: 800720
          decay_factor: 0.949999988079
        }
      }
      momentum_optimizer_value: 0.899999976158
      decay: 0.899999976158
      epsilon: 1.0
    }
  }
  fine_tune_checkpoint: ".... ... /model.ckpt"  # change: pre-trained model checkpoint
  from_detection_checkpoint: true  # change
  fine_tune_checkpoint_type: "detection"  # change
  use_bfloat16: false  # change
  num_steps: 200000
}
train_input_reader {
  label_map_path: "... ... .pbtxt"  # change
  tf_record_input_reader {
    input_path: "D:/datasets/... ... _0-500.record"  # change
    input_path: "D:/datasets/... ... _500-1000.record"  # change
  }
}
eval_config {
  metrics_set: "coco_detection_metrics"  # change
  use_moving_averages: false  # change
  num_examples: 50  # change
  max_evals: 10
}
eval_input_reader {
  label_map_path: "... ... .pbtxt"  # change
  shuffle: false
  num_epochs: 1
  tf_record_input_reader {
    input_path: "D:/datasets/... ... _50-100.record"  # change
  }
}
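A common slip when editing this file is leaving num_classes out of step with the label map. The sketch below is a rough consistency check using plain regular expressions rather than a real protobuf parser, so commented-out lines are counted too; all three function names are hypothetical.

```python
import re


def count_label_map_items(label_map_text):
    """Count the 'item {' entries in label-map text."""
    return len(re.findall(r"\bitem\s*\{", label_map_text))


def read_num_classes(pipeline_text):
    """Return the first num_classes value in the config text, or None if absent."""
    match = re.search(r"\bnum_classes\s*:\s*(\d+)", pipeline_text)
    return int(match.group(1)) if match else None


def classes_match(pipeline_text, label_map_text):
    """True when num_classes equals the number of label-map items."""
    return read_num_classes(pipeline_text) == count_label_map_items(label_map_text)
```

Reading both files with open(...).read() and passing the text to classes_match gives a quick check before a long training run is started.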
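Training and converting the model after training
Copy export_inference_graph.py and model_main.py from models/research/object_detection/ into training_demo:
training_demo/
├─ annotations/ dataset annotations
├─ exported-models/ exported trained models
├─ models/ configuration files
├─ pre-trained-models/ pre-trained models
├─ export_inference_graph.py
├─ model_main.py
└─ README.md
Adjust the arguments and run model_main.py to start training.
Adjust the arguments and run export_inference_graph.py to convert the ckpt checkpoint into a .pb file.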
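The arguments for both scripts can be passed on the command line. The sketch below only assembles the two command lines; the flag names are the ones I recall from the TF1 versions of the scripts, so confirm them with --help on your checkout. The helper names and any paths you pass in are hypothetical.

```python
import sys


def train_command(pipeline_config_path, model_dir, num_train_steps):
    """Build the argument list for model_main.py (flag names assumed from the TF1 script)."""
    return [
        sys.executable, "model_main.py",
        "--pipeline_config_path=" + pipeline_config_path,
        "--model_dir=" + model_dir,
        "--num_train_steps=" + str(num_train_steps),
        "--alsologtostderr",
    ]


def export_command(pipeline_config_path, trained_checkpoint_prefix, output_directory):
    """Build the argument list for export_inference_graph.py (flag names assumed)."""
    return [
        sys.executable, "export_inference_graph.py",
        "--input_type=image_tensor",
        "--pipeline_config_path=" + pipeline_config_path,
        "--trained_checkpoint_prefix=" + trained_checkpoint_prefix,
        "--output_directory=" + output_directory,
    ]
```

Either list can be handed to subprocess.run(cmd, check=True) from inside training_demo. The checkpoint prefix is the path of a saved checkpoint without its file extension.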
Testing
# Uses the TensorFlow 1.x API (tf.GraphDef, tf.gfile, tf.Session).
import cv2
import numpy as np
import tensorflow as tf
from PIL import Image

PATH_TO_CKPT = r'... .../frozen_inference_graph.pb'  # path to the .pb file produced in the previous step
PATH_TO_LABELS = r'... ... _map.pbtxt'  # path to the label map .pbtxt file
NUM_CLASSES = 12  # number of classes (12 here); change to match your dataset

# Load the frozen inference graph.
detection_graph = tf.Graph()
with detection_graph.as_default():
    od_graph_def = tf.GraphDef()
    with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
        od_graph_def.ParseFromString(fid.read())
    tf.import_graph_def(od_graph_def, name='')


def load_image_into_numpy_array(image):
    # convert('RGB') guarantees 3 channels, so grayscale and RGBA inputs also work.
    return np.array(image.convert('RGB'), dtype=np.uint8)


image_path = r'./images/3.jpg'  # path to the test image
with detection_graph.as_default():
    with tf.Session(graph=detection_graph) as sess:
        image = Image.open(image_path)
        image_np = load_image_into_numpy_array(image)
        # The model expects a batch of shape [1, height, width, 3].
        image_np_expanded = np.expand_dims(image_np, axis=0)
        image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
        boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
        scores = detection_graph.get_tensor_by_name('detection_scores:0')
        classes = detection_graph.get_tensor_by_name('detection_classes:0')
        num_detections = detection_graph.get_tensor_by_name('num_detections:0')
        # Actual detection.
        (boxes, scores, classes, num_detections) = sess.run(
            [boxes, scores, classes, num_detections],
            feed_dict={image_tensor: image_np_expanded})
        print("boxes", boxes)
        print("scores", scores)
        print("classes", classes)
        print("num_detections", num_detections)
        h, w, _ = image_np.shape
        image_cv = cv2.imread(image_path)
        for index in range(int(num_detections[0])):
            # Detections are sorted by descending score, so stop at the first low one.
            if scores[0][index] < 0.5:
                break
            # Boxes are normalized [ymin, xmin, ymax, xmax]; scale them to pixels.
            y1 = int(boxes[0][index][0] * h)
            x1 = int(boxes[0][index][1] * w)
            y2 = int(boxes[0][index][2] * h)
            x2 = int(boxes[0][index][3] * w)
            image_cv = cv2.rectangle(image_cv, (x1, y1), (x2, y2), (255, 255, 0), 4)
        cv2.imwrite("result.jpg", image_cv)
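The box handling in the script above can be isolated into two small functions, which makes it easy to check without loading a model. A minimal sketch in plain Python; both helper names are hypothetical.

```python
def to_pixel_box(box, image_width, image_height):
    """Convert a normalized [ymin, xmin, ymax, xmax] box to pixel (x1, y1, x2, y2)."""
    ymin, xmin, ymax, xmax = box
    return (int(xmin * image_width), int(ymin * image_height),
            int(xmax * image_width), int(ymax * image_height))


def keep_confident(boxes, scores, threshold=0.5):
    """Return the boxes whose score is at least threshold."""
    return [box for box, score in zip(boxes, scores) if score >= threshold]
```

Note the axis order: the model reports y before x, while cv2.rectangle expects (x, y) corner points.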
–end–
Bugs
bug-1: [object detection] TypeError: can't pickle dict_values objects
Fix: https://github.com/tensorflow/models/issues/4780
Cause: a Python 3 incompatibility. In Python 3, dict.values() returns a view object, which cannot be pickled.
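The error can be reproduced outside TensorFlow. The sketch below uses a made-up category_index dictionary to show the failure and the usual remedy of wrapping the view in list(). As far as I recall, the fix discussed in the linked issue applies the same wrapping to category_index.values() in object_detection/model_lib.py, but verify that against the thread.

```python
import pickle

# Made-up stand-in for the category index built from a label map.
category_index = {1: {"id": 1, "name": "cat"}, 2: {"id": 2, "name": "dog"}}

try:
    pickle.dumps(category_index.values())  # a dict_values view object
    view_picklable = True
except TypeError:
    view_picklable = False

# Wrapping the view in list() yields an ordinary list, which pickles fine.
restored = pickle.loads(pickle.dumps(list(category_index.values())))
```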
bug-2: installing pycocotools on Windows 10
Fix:
python setup.py build_ext --inplace
python setup.py build_ext install
If that fails, change setup.py as follows:
ext_modules = [
    Extension(
        'pycocotools._mask',
        sources=['../common/maskApi.c', 'pycocotools/_mask.pyx'],
        include_dirs=[np.get_include(), '../common'],
        extra_compile_args=['-std=c99'],
    )
]
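bug-3: no output is shown
Fix: raise the log level to INFO:
tf.logging.set_verbosity(tf.logging.INFO)
bug-4: training exits after a single run
Cause: the num_train_steps argument passed to model_main.py was wrong. Training runs for exactly as many steps as you specify.
Configuration notes
- The config file can list several tfrecord files.
- All paths in the config file use / rather than \. Chinese characters in paths are supported.
- After training finishes, running it again without changing the paths continues training from where it stopped.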