Object Detection with TensorFlow (Part 2): Training and Detection

With the data prepared in the previous part, we can now configure training and run object detection.

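1. Configure the object detection model
Download the Models source code (the tensorflow/models repository).
Follow the object detection installation guide to configure and install it.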

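2. Download and configure the COCO pre-trained model
Download the "COCO-pretrained Faster R-CNN with Resnet-101 model" mentioned in the official documentation. After extracting it, copy the three files whose names start with model.ckpt into the training directory.
Copy models/research/object_detection/samples/configs/faster_rcnn_resnet101_coco.config into the training directory.
Edit faster_rcnn_resnet101_coco.config:
1) Change num_classes to 1, since this example has only one class.
2) Change every PATH_TO_BE_CONFIGURED to a path on your own machine; there are 5 occurrences.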
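The two config edits above can also be scripted. The following is only a minimal sketch, not part of the original workflow: the helper name patch_config and the shortened sample text are made up for illustration, and it assumes that plain text substitution is acceptable for this config. It only swaps the placeholder prefix, so the file names that follow each placeholder still have to match your own record and label map files.

```python
import re

PLACEHOLDER = "PATH_TO_BE_CONFIGURED"


def patch_config(text, num_classes, base_dir):
    """Return (patched_text, number_of_placeholders_replaced).

    Hypothetical helper: rewrites the num_classes value and replaces the
    placeholder prefix with base_dir. File names after the placeholder
    are left untouched.
    """
    text = re.sub(r"num_classes:\s*\d+", "num_classes: %d" % num_classes, text)
    count = text.count(PLACEHOLDER)
    return text.replace(PLACEHOLDER, base_dir), count


# Shortened stand-in for the real config file (illustrative only).
sample = '''
model { faster_rcnn { num_classes: 90 } }
train_config { fine_tune_checkpoint: "PATH_TO_BE_CONFIGURED/model.ckpt" }
train_input_reader { label_map_path: "PATH_TO_BE_CONFIGURED/label_map.pbtxt" }
'''

patched, replaced = patch_config(
    sample, 1, "/home/paiconor/object-detection/training")
print(patched)
print("placeholders replaced:", replaced)
```

On the real file, the reported count should be 5; a different number suggests that the config differs from the one described here.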

3. Training
Run the following script to start training:

python object_detection/train.py \
    --logtostderr \
    --pipeline_config_path=/home/paiconor/object-detection/training/faster_rcnn_resnet101_coco.config \
    --train_dir=/home/paiconor/object-detection/training/trainingresult

In practice, training could not be done on MOHI: its 16 GB of memory was not enough.

During training, you can monitor progress with TensorBoard using the following command:

tensorboard --logdir=/home/paiconor/object-detection/training/trainingresult

Use the TotalLoss chart to follow the training. Overall, the TotalLoss curve converges, so the training result is satisfactory.
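"Converges" here is judged by eye from the TensorBoard chart. As a rough illustration of the same idea in code, the sketch below smooths a loss series with an exponential moving average and compares the start of the curve with its end. The function names and the loss values are invented for illustration; they are not the losses from this training run, and the check is a simple heuristic rather than a formal convergence test.

```python
def ema(values, weight=0.6):
    """Exponential moving average over a list of loss values."""
    smoothed = []
    last = values[0]
    for v in values:
        last = weight * last + (1 - weight) * v
        smoothed.append(last)
    return smoothed


def is_decreasing(values, window=3):
    """True if the smoothed tail of the curve is lower than its head."""
    s = ema(values)
    head = sum(s[:window]) / window
    tail = sum(s[-window:]) / window
    return tail < head


# Made-up, noisy but falling loss values (not measured data).
losses = [2.0, 1.5, 1.6, 1.0, 0.8, 0.9, 0.5, 0.4]
print(is_decreasing(losses))
```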

4. Freeze the model (export)
Before checking how the model actually performs, we need to export the training checkpoint files into a .pb model file.
Run the following script:

python object_detection/export_inference_graph.py \
    --input_type=image_tensor \
    --pipeline_config_path=/home/paiconor/object-detection/training/faster_rcnn_resnet101_coco.config \
    --trained_checkpoint_prefix=/home/paiconor/object-detection/training/export/model.ckpt-2107 \
    --output_directory=/home/paiconor/object-detection/training/output_inference_graph.pb

After the export finishes, output_directory will contain frozen_inference_graph.pb, model.ckpt.data-00000-of-00001, model.ckpt.meta, and model.ckpt.index.
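The --trained_checkpoint_prefix argument needs the step number of a saved checkpoint (2107 in the command above). The sketch below shows one way to pick the highest step from a list of file names. The helper name is hypothetical, and it assumes the usual TensorFlow 1.x naming where each checkpoint has a model.ckpt-<step>.index file.

```python
import re


def latest_checkpoint_prefix(filenames):
    """Return 'model.ckpt-<max step>' or None if no checkpoint is found."""
    steps = []
    for name in filenames:
        match = re.fullmatch(r"model\.ckpt-(\d+)\.index", name)
        if match:
            steps.append(int(match.group(1)))
    if not steps:
        return None
    return "model.ckpt-%d" % max(steps)


# Example listing of a training directory (illustrative file names).
files = [
    "checkpoint",
    "model.ckpt-0.index",
    "model.ckpt-2107.index",
    "model.ckpt-2107.meta",
    "model.ckpt-2107.data-00000-of-00001",
]
print(latest_checkpoint_prefix(files))
```

For a real directory, pass os.listdir(...) of the checkpoint directory as the argument. TensorFlow's own tf.train.latest_checkpoint serves a similar purpose by reading the checkpoint file in that directory.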

5. Run object detection with the trained model
Write the detection script:

import cv2
import numpy as np
import tensorflow as tf

from object_detection.utils import label_map_util
from object_detection.utils import visualization_utils as vis_util


class TOD(object):
    def __init__(self):
        # Frozen graph exported in step 4 and the label map for the single class
        self.PATH_TO_CKPT = r'/home/paiconor/object-detection/training/output_inference_graph.pb/frozen_inference_graph.pb'
        self.PATH_TO_LABELS = r'/home/paiconor/object-detection/training/hand.pbtxt'
        self.NUM_CLASSES = 1
        self.detection_graph = self._load_model()
        self.category_index = self._load_label_map()

    def _load_model(self):
        # Load the serialized frozen graph into a fresh tf.Graph
        detection_graph = tf.Graph()
        with detection_graph.as_default():
            od_graph_def = tf.GraphDef()
            with tf.gfile.GFile(self.PATH_TO_CKPT, 'rb') as fid:
                serialized_graph = fid.read()
                od_graph_def.ParseFromString(serialized_graph)
                tf.import_graph_def(od_graph_def, name='')
        return detection_graph

    def _load_label_map(self):
        # Map class ids to display names
        label_map = label_map_util.load_labelmap(self.PATH_TO_LABELS)
        categories = label_map_util.convert_label_map_to_categories(
            label_map,
            max_num_classes=self.NUM_CLASSES,
            use_display_name=True)
        category_index = label_map_util.create_category_index(categories)
        return category_index

    def detect(self, image):
        with self.detection_graph.as_default():
            with tf.Session(graph=self.detection_graph) as sess:
                # cv2.imread returns BGR, while the model was trained on RGB input
                image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
                # Expand dimensions since the model expects images to have shape: [1, None, None, 3]
                image_np_expanded = np.expand_dims(image_rgb, axis=0)
                image_tensor = self.detection_graph.get_tensor_by_name('image_tensor:0')
                boxes = self.detection_graph.get_tensor_by_name('detection_boxes:0')
                scores = self.detection_graph.get_tensor_by_name('detection_scores:0')
                classes = self.detection_graph.get_tensor_by_name('detection_classes:0')
                num_detections = self.detection_graph.get_tensor_by_name('num_detections:0')
                # Actual detection.
                (boxes, scores, classes, num_detections) = sess.run(
                    [boxes, scores, classes, num_detections],
                    feed_dict={image_tensor: image_np_expanded})
                # Visualization of the results of a detection (drawn in place on the BGR image).
                vis_util.visualize_boxes_and_labels_on_image_array(
                    image,
                    np.squeeze(boxes),
                    np.squeeze(classes).astype(np.int32),
                    np.squeeze(scores),
                    self.category_index,
                    use_normalized_coordinates=True,
                    line_thickness=8)

        cv2.namedWindow("detection", cv2.WINDOW_NORMAL)
        cv2.imshow("detection", image)
        cv2.waitKey(0)


if __name__ == '__main__':
    image = cv2.imread('image.jpeg')
    detector = TOD()
    detector.detect(image)

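Detection results:

For now, detection works well against plain, single-color backgrounds, but misdetections occur when the background is more complex, so further tuning is still needed.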
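When looking into such misdetections, it can help to inspect the raw outputs instead of only the drawn image. The sketch below is an illustration, not part of the original script: the function name and the sample numbers are made up. It relies on the Object Detection API convention that each box is [ymin, xmin, ymax, xmax] in normalized coordinates, keeps only detections at or above a score threshold, and converts them to pixel coordinates.

```python
def boxes_to_pixels(boxes, scores, image_height, image_width, min_score=0.5):
    """Return (left, top, right, bottom, score) tuples for confident detections.

    boxes holds [ymin, xmin, ymax, xmax] rows with values in [0, 1].
    """
    results = []
    for (ymin, xmin, ymax, xmax), score in zip(boxes, scores):
        if score < min_score:
            continue  # drop low-confidence detections
        results.append((
            int(round(xmin * image_width)),
            int(round(ymin * image_height)),
            int(round(xmax * image_width)),
            int(round(ymax * image_height)),
            score,
        ))
    return results


# Made-up detections for a 200x100 (width x height) image.
sample_boxes = [[0.1, 0.2, 0.5, 0.6], [0.0, 0.0, 1.0, 1.0]]
sample_scores = [0.9, 0.3]
kept = boxes_to_pixels(sample_boxes, sample_scores, 100, 200)
print(kept)
```

With the script above, np.squeeze(boxes).tolist() and np.squeeze(scores).tolist() could serve as the inputs. Raising min_score trades missed detections for fewer false positives, so the right value depends on your own images.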

Please credit the source when reposting: http://www.baiguangnan.com/2018/02/12/objectdetection2/
