[Tensorflow][Train] Training on your own data (locally) with the TensorFlow Object Detection API

I. Download the Object Detection API

git clone https://github.com/tensorflow/models.git

II. Install dependencies and verify that they work

sudo apt-get install protobuf-compiler python-pil python-lxml
sudo pip install jupyter
sudo pip install matplotlib
cd /localrepo/tensorflow/models/research
protoc object_detection/protos/*.proto --python_out=.
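A quick way to confirm that the protoc step worked is to check that every .proto file now has a generated *_pb2.py next to it. This is my own sketch, not part of the API; the helper name missing_pb2 is made up, and it assumes you run it from models/research.

```python
"""Sketch: check that protoc produced a *_pb2.py for every .proto file.

Assumes it is run from models/research, where the protoc command above
writes the generated modules next to the .proto sources.
"""
import glob
import os


def missing_pb2(protos_dir="object_detection/protos"):
    """Return the names of .proto files that have no matching *_pb2.py."""
    missing = []
    for proto in sorted(glob.glob(os.path.join(protos_dir, "*.proto"))):
        stem = os.path.splitext(proto)[0]
        if not os.path.exists(stem + "_pb2.py"):
            missing.append(os.path.basename(proto))
    return missing


if __name__ == "__main__":
    if not os.path.isdir("object_detection/protos"):
        print("run this from models/research")
    else:
        gaps = missing_pb2()
        print("all protos compiled" if not gaps else "not compiled: %s" % gaps)
```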

1. Add the environment variable to .bashrc

export PYTHONPATH=$PYTHONPATH:/localrepo/tensorflow/models/research:/localrepo/tensorflow/models/research/slim

2. Verify the installation

/localrepo/tensorflow/models/research$ python object_detection/builders/model_builder_test.py
...........
----------------------------------------------------------------------
Ran 11 tests in 0.022s

OK

III. Prepare your own dataset
1. First, create the following directory layout:

research
-our_train
--annotations/
---trainval.txt
---xmls/
---label.pbtxt
--images/

2. Generate trainval.txt (run from inside our_train/)

ls images | grep ".jpg" | sed s/.jpg// > annotations/trainval.txt
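If you prefer Python to the shell pipeline, the same list can be written like this. It is a sketch that assumes it runs from inside our_train/ (so images/ and annotations/ are relative paths); write_trainval is a hypothetical helper name. Unlike the grep/sed version it also accepts .JPG and only strips the extension at the end of the name.

```python
"""Python equivalent of the ls | grep | sed pipeline above (a sketch).

Assumes it is run from inside our_train/, with images/ and annotations/
next to each other, as in the directory layout above.
"""
import os


def write_trainval(image_dir="images", out_path="annotations/trainval.txt"):
    """Write one image basename (without .jpg) per line; return the names."""
    names = sorted(
        os.path.splitext(f)[0]
        for f in os.listdir(image_dir)
        if f.lower().endswith(".jpg")
    )
    with open(out_path, "w") as fh:
        fh.writelines(name + "\n" for name in names)
    return names
```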

3. Label map: label.pbtxt

item {
  id: 1
  name: 'Test'
}

Note: every name in label.pbtxt must match the class_name computed in the dict_to_tf_example function of create_pet_tf_record.py.
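To catch a mismatch before generating records, a small checker helps. Assumption, worth verifying against your copy of the script: in the version of create_pet_tf_record.py I know, class_name comes from the image filename (the prefix before "_<digits>.jpg", e.g. Test_1.jpg gives Test), so the check compares filename prefixes with the label map names. The label map is read with a simple regex rather than the protobuf parser, and all function names here are hypothetical.

```python
"""Check that every class name derived from image filenames exists in label.pbtxt.

Assumption (hedged): like the pet script's filename-based class lookup as I
remember it, the class is the filename prefix before "_<digits>.jpg".
The label map is parsed with a simple regex, not the protobuf parser.
"""
import re

# Matches items written as "id: N" followed by "name: '...'".
ITEM_RE = re.compile(r"id:\s*(\d+)\s*name:\s*['\"]([^'\"]+)['\"]")
FILENAME_RE = re.compile(r"([A-Za-z_]+)(_[0-9]+\.jpg)", re.I)


def parse_label_map(text):
    """Return {name: id} for items written as id first, then name."""
    return {name: int(idx) for idx, name in ITEM_RE.findall(text)}


def class_from_filename(filename):
    """Return the class prefix, or None if the filename does not match."""
    match = FILENAME_RE.match(filename)
    return match.group(1) if match else None


def unknown_classes(filenames, label_map_text):
    """Return filenames whose derived class is missing from the label map."""
    names = parse_label_map(label_map_text)
    return [f for f in filenames if class_from_filename(f) not in names]
```

The lookup is case-sensitive, so test_2.jpg would be flagged against a label map that only contains 'Test'.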

4. Generate the TFRecords

python object_detection/create_pet_tf_record.py --label_map_path=our_train/label.pbtxt  --data_dir=our_train/ --output_dir=our_train/

There are two scripts that convert a dataset into TFRecords:

/localrepo/tensorflow/models/research/object_detection/create_pascal_tf_record.py
/localrepo/tensorflow/models/research/object_detection/create_pet_tf_record.py
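Both scripts read Pascal VOC-style XML and, as far as I can tell, their dict_to_tf_example divides each box corner by the image width/height so the stored coordinates lie in [0, 1]. The sketch below shows just that parsing step with the standard library's ElementTree (not lxml, which the scripts use); read_voc_boxes is a hypothetical name, and it returns the object's own name field, whereas the pet script derives the class from the filename.

```python
"""Sketch: read a Pascal VOC-style XML annotation into normalized boxes.

Box corners are divided by image width/height so they fall in [0, 1].
Uses the stdlib ElementTree instead of lxml.
"""
import xml.etree.ElementTree as ET


def read_voc_boxes(xml_bytes):
    """Return (filename, [(name, xmin, ymin, xmax, ymax), ...]) normalized."""
    root = ET.fromstring(xml_bytes)  # bytes input may carry an <?xml ...?> header
    width = float(root.findtext("size/width"))
    height = float(root.findtext("size/height"))
    boxes = []
    for obj in root.findall("object"):
        bb = obj.find("bndbox")
        boxes.append((
            obj.findtext("name"),
            float(bb.findtext("xmin")) / width,
            float(bb.findtext("ymin")) / height,
            float(bb.findtext("xmax")) / width,
            float(bb.findtext("ymax")) / height,
        ))
    return root.findtext("filename"), boxes
```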

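5. Train
Choose a config file and a ckpt model. The two must match: I used a mismatched pair once, training would not run, and it took me a long time to find the cause. Also, every PATH_TO_BE_CONFIGURED placeholder in the config file has to be filled in by hand.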

python object_detection/train.py --logtostderr --pipeline_config_path=our_train/faster_rcnn_resnet101_pets.config --train_dir=our_train/
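Filling in the placeholders can be scripted. This is a sketch under a strong assumption: the checkpoint, TFRecords and label map all sit in one directory, so a single substitution is enough; if they do not, edit each field by hand. fill_config and fill_config_file are hypothetical helper names.

```python
"""Sketch: fill in PATH_TO_BE_CONFIGURED in a pipeline .config file.

Assumes the checkpoint, TFRecords and label map all live in ONE directory;
if they do not, edit each field by hand instead.
"""
PLACEHOLDER = "PATH_TO_BE_CONFIGURED"


def fill_config(text, target_dir):
    """Return (new_text, number_of_placeholders_replaced)."""
    count = text.count(PLACEHOLDER)
    return text.replace(PLACEHOLDER, target_dir.rstrip("/")), count


def fill_config_file(path, target_dir):
    """Rewrite the config file in place and report how many fields changed."""
    with open(path) as fh:
        text = fh.read()
    new_text, count = fill_config(text, target_dir)
    with open(path, "w") as fh:
        fh.write(new_text)
    return count
```

Only the directory part is substituted; the file names that follow each placeholder in the sample config still have to match your own files (for example label.pbtxt instead of a pet label map).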

. Download the VOC dataset and generate records

wget http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar

python object_detection/create_pascal_tf_record.py  --label_map_path=object_detection/data/pascal_label_map.pbtxt  --data_dir=voc_train/VOCdevkit --year=VOC2012 --set=train --output_path=voc_train/pascal_train.record 

python object_detection/create_pascal_tf_record.py  --label_map_path=object_detection/data/pascal_label_map.pbtxt  --data_dir=voc_train/VOCdevkit --year=VOC2012 --set=val --output_path=voc_train/pascal_val.record 

This produces pascal_train.record and pascal_val.record in the voc_train directory.
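The commands above assume the tarball has been extracted inside voc_train/, so that voc_train/VOCdevkit/VOC2012/ exists (matching --data_dir and --year). A quick pre-flight check of the standard VOC subdirectories, as a sketch with a hypothetical helper name:

```python
"""Sketch: check the extracted VOC2012 layout before running the script.

Assumes the tarball was extracted inside voc_train/, so that
voc_train/VOCdevkit/VOC2012/ exists (matching --data_dir and --year above).
"""
import os

VOC_SUBDIRS = ("Annotations", "JPEGImages", os.path.join("ImageSets", "Main"))


def missing_voc_dirs(data_dir="voc_train/VOCdevkit", year="VOC2012"):
    """Return the expected subdirectories that do not exist."""
    base = os.path.join(data_dir, year)
    return [d for d in VOC_SUBDIRS if not os.path.isdir(os.path.join(base, d))]
```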

VI. Download the Oxford-IIIT Pet dataset and generate records
Directory tree:
research
-oxford-pet
--annotations/
---trainval.txt
---xmls/
--images/
-object_detection
--data
---pet_label_map.pbtxt

1. Download

wget http://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz

wget http://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz

tar -xvf annotations.tar.gz
tar -xvf images.tar.gz
python object_detection/create_pet_tf_record.py \
    --label_map_path=object_detection/data/pet_label_map.pbtxt \
    --data_dir=`pwd` \
    --output_dir=`pwd`

2. Generate the TFRecords:

python object_detection/create_pet_tf_record.py --label_map_path=object_detection/data/pet_label_map.pbtxt  --data_dir=oxford-pet/ --output_dir=oxford-pet/

The resulting files are oxford-pet/pet_train.record and oxford-pet/pet_val.record.
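The train and val records both come from the single trainval.txt list. To my knowledge the pet script shuffles that list with a fixed seed (random.seed(42)) and puts the first 70% into train; check your copy before relying on the exact ratio. A reproducible split along those lines, as a sketch (split_examples is a hypothetical name, and it does not promise the same order as the script):

```python
"""Sketch of a reproducible 70/30 train/val split with a fixed seed."""
import random


def split_examples(examples, train_fraction=0.7, seed=42):
    """Shuffle a copy of examples with a fixed seed and split it."""
    shuffled = list(examples)
    random.Random(seed).shuffle(shuffled)  # private RNG: global state untouched
    num_train = int(train_fraction * len(shuffled))
    return shuffled[:num_train], shuffled[num_train:]
```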

VII. Training on Oxford-pet

export PATH_TO_BE_CONFIGURED=/localrepo/tensorflow/models/research/object_detection/samples/configs/

# From the tensorflow/models/research/ directory
python object_detection/train.py --logtostderr --pipeline_config_path=object_detection/samples/configs/ssd_mobilenet_v1_pets.config  --train_dir=oxford-pet/

Download the COCO-pretrained Faster R-CNN ResNet-101 model:

wget http://storage.googleapis.com/download.tensorflow.org/models/object_detection/faster_rcnn_resnet101_coco_11_06_2017.tar.gz
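The config's fine_tune_checkpoint field expects a checkpoint prefix (a path ending in model.ckpt, without the .index/.data-* suffix). A sketch that unpacks the download and lists the prefixes it contains; the output directory is your choice and extract_and_find_ckpt is a hypothetical helper:

```python
"""Sketch: unpack a downloaded model tarball and find its checkpoint prefixes."""
import glob
import os
import tarfile


def extract_and_find_ckpt(tar_path, out_dir):
    """Extract tar_path into out_dir; return sorted checkpoint prefixes found."""
    with tarfile.open(tar_path, "r:gz") as tar:
        if hasattr(tarfile, "data_filter"):
            tar.extractall(out_dir, filter="data")  # safer extraction where supported
        else:
            tar.extractall(out_dir)
    pattern = os.path.join(out_dir, "**", "*.ckpt*.index")
    # Strip ".index" so the result is the prefix that fine_tune_checkpoint expects.
    return sorted(p[: -len(".index")] for p in glob.glob(pattern, recursive=True))
```

For example, extract_and_find_ckpt("faster_rcnn_resnet101_coco_11_06_2017.tar.gz", "pretrained") should return a path ending in model.ckpt, provided the archive contains model.ckpt.* files as expected.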

Some errors I ran into:
Error 1)

 python research/object_detection/create_pet_tf_record.py  --label_map_path=research/object_detection/data/pet_label_map.pbtxt --data_dir=./ --output_dir=./ 
Traceback (most recent call last):
  File "research/object_detection/create_pet_tf_record.py", line 217, in <module>
    tf.app.run()
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/platform/app.py", line 48, in run
    _sys.exit(main(_sys.argv[:1] + flags_passthrough))
  File "research/object_detection/create_pet_tf_record.py", line 212, in main
    image_dir, train_examples)
  File "research/object_detection/create_pet_tf_record.py", line 178, in create_tf_record
    print '3betsy:' + xml
TypeError: cannot concatenate 'str' and 'lxml.etree._Element' objects

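After upgrading to Python 3 the problem went away. (The traceback points at a debug print on line 178, print '3betsy:' + xml, which concatenates a str with an lxml Element.)
Error 2)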

python object_detection/train.py --pipeline_config_path=our_train/ssd_mobilenet_v1_pets.config --train_dir=oxford-pet/
WARNING:tensorflow:From /localrepo/tensorflow/models/research/object_detection/trainer.py:210: create_global_step (from tensorflow.contrib.framework.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Please switch to tf.train.create_global_step
INFO:tensorflow:depth of additional conv before box predictor: 0
INFO:tensorflow:depth of additional conv before box predictor: 0
INFO:tensorflow:depth of additional conv before box predictor: 0
INFO:tensorflow:depth of additional conv before box predictor: 0
INFO:tensorflow:depth of additional conv before box predictor: 0
INFO:tensorflow:depth of additional conv before box predictor: 0
ERROR:root:betsy PATH_TO_BE_CONFIGURED/model.ckpt
Traceback (most recent call last):
  File "object_detection/train.py", line 163, in <module>
    tf.app.run()
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/platform/app.py", line 48, in run
    _sys.exit(main(_sys.argv[:1] + flags_passthrough))
  File "object_detection/train.py", line 159, in main
    worker_job_name, is_chief, FLAGS.train_dir)
  File "/localrepo/tensorflow/models/research/object_detection/trainer.py", line 254, in train
    var_map, train_config.fine_tune_checkpoint))
  File "/localrepo/tensorflow/models/research/object_detection/utils/variables_helper.py", line 123, in get_variables_available_in_checkpoint
    ckpt_reader = tf.train.NewCheckpointReader(checkpoint_path)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/pywrap_tensorflow_internal.py", line 150, in NewCheckpointReader
    return CheckpointReader(compat.as_bytes(filepattern), status)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/errors_impl.py", line 473, in __exit__
    c_api.TF_GetCode(self.status.status))
tensorflow.python.framework.errors_impl.InvalidArgumentError: Unsuccessful TensorSliceReader constructor: Failed to get matching files on PATH_TO_BE_CONFIGURED/model.ckpt: Not found: PATH_TO_BE_CONFIGURED; No such file or directory

The cause: fine_tune_checkpoint in the config still holds the literal placeholder PATH_TO_BE_CONFIGURED/model.ckpt. Exporting a shell variable with that name does not substitute it inside the config file; the placeholder has to be replaced with the real path in the config itself.

Error 3)

Unicode strings with encoding declaration are not supported. Please use bytes input or XML fragments without declaration.

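This happens because my own script generated the XML files with xml.dom.minidom, which writes an XML declaration header. create_pet_tf_record.py parses the files with lxml, which rejects a Unicode string that carries an encoding declaration. There are two fixes:
1. Workaround: batch-remove the header from the existing files,
using vim: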

  :args *.xml
  :argdo %s/<?xml version=\"1.0\" encoding=\"utf\-8\"?>//ge | update 
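The same cleanup as a Python script, as a sketch: it strips a leading <?xml ...?> declaration of any version/encoding from every .xml file in one directory. The default directory and the helper name strip_declarations are assumptions; back up the files first, since they are rewritten in place.

```python
"""Sketch: remove a leading <?xml ...?> declaration from every .xml file."""
import glob
import os
import re

# Optional leading whitespace, the declaration itself, and whitespace after it.
DECL_RE = re.compile(r"^\s*<\?xml[^>]*\?>\s*")


def strip_declarations(xml_dir="annotations/xmls"):
    """Rewrite files in place; return how many files were changed."""
    changed = 0
    for path in glob.glob(os.path.join(xml_dir, "*.xml")):
        with open(path, encoding="utf-8") as fh:
            text = fh.read()
        new_text = DECL_RE.sub("", text, count=1)
        if new_text != text:
            with open(path, "w", encoding="utf-8") as fh:
                fh.write(new_text)
            changed += 1
    return changed
```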

2. Fix it at the source:
generate the XML with lxml instead.
pip3 install lxml
Some rough code I use for this myself:
http://download.csdn.net/download/wlnvgu/10146121
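What matters is that the written files carry no XML declaration. The sketch below swaps in the standard library's xml.etree.ElementTree for lxml so it runs without extra packages (lxml.etree provides the same Element/SubElement/tostring calls used here). The element names follow the Pascal VOC format; voc_xml is a hypothetical name and the pose/truncated/difficult defaults are my assumptions.

```python
"""Sketch: build a Pascal VOC-style annotation with no XML declaration."""
import xml.etree.ElementTree as ET


def voc_xml(filename, width, height, objects):
    """Build annotation bytes; objects is a list of (name, xmin, ymin, xmax, ymax)."""
    root = ET.Element("annotation")
    ET.SubElement(root, "filename").text = filename
    size = ET.SubElement(root, "size")
    for tag, value in (("width", width), ("height", height), ("depth", 3)):
        ET.SubElement(size, tag).text = str(value)
    for name, xmin, ymin, xmax, ymax in objects:
        obj = ET.SubElement(root, "object")
        ET.SubElement(obj, "name").text = name
        ET.SubElement(obj, "pose").text = "Unspecified"
        ET.SubElement(obj, "truncated").text = "0"
        ET.SubElement(obj, "difficult").text = "0"
        box = ET.SubElement(obj, "bndbox")
        for tag, value in zip(("xmin", "ymin", "xmax", "ymax"), (xmin, ymin, xmax, ymax)):
            ET.SubElement(box, tag).text = str(value)
    # tostring() with its default us-ascii encoding writes no <?xml ...?> line.
    return ET.tostring(root)
```

Write the returned bytes with open(path, "wb") to produce one file per image.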

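VIII. After training
When training finishes, ckpt files are written to the train directory. You can use TensorBoard to inspect the training process, for example how the loss decreases.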

tensorboard --logdir=our_train

You can export the graph with the export command below, which generates saved_model.pb:

python object_detection/export_inference_graph.py     --input_type image_tensor     --pipeline_config_path our_train/faster_rcnn_resnet101_pets.config --trained_checkpoint_prefix our_train/model.ckpt-64800  --output_directory our_train/
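The step number in --trained_checkpoint_prefix (64800 here) depends on how long you trained. A small sketch to find the newest prefix from the file names alone; the helper name is hypothetical, and inside TensorFlow, tf.train.latest_checkpoint(train_dir), which reads the "checkpoint" state file, is the usual way to do this.

```python
"""Sketch: find the newest model.ckpt-<step> prefix in a train_dir."""
import glob
import os
import re

STEP_RE = re.compile(r"model\.ckpt-(\d+)\.index$")


def latest_checkpoint_prefix(train_dir):
    """Return the prefix with the highest step number, or None if none exist."""
    best_step, best_prefix = -1, None
    for path in glob.glob(os.path.join(train_dir, "model.ckpt-*.index")):
        match = STEP_RE.search(os.path.basename(path))
        # Compare steps numerically, so 64800 beats 900 (a string sort would not).
        if match and int(match.group(1)) > best_step:
            best_step = int(match.group(1))
            best_prefix = path[: -len(".index")]
    return best_prefix
```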

You can check the trained model on your own images with
object_detection_tutorial.ipynb
