Tensorflow object detection API 搭建属于自己的物体识别模型——常见问题汇总 Q&A

这篇博客汇总了使用Tensorflow object detection API进行物体检测时遇到的常见问题及其解决方案,包括图片格式错误、数据类型问题、模型效果不佳、XML标注文件错误等,并提供了相应的代码修改建议和资源链接。
摘要由CSDN通过智能技术生成

在上一篇博客《(更新视频教程)Tensorflow object detection API 搭建属于自己的物体识别模型(2)——训练并使用自己的模型》中,有很多读者提出各种问题,也有不少热心读者在评论区进行了讨论,为了方便读者查询及备忘,现把部分问题及解决方案更新到这个博客,不定期更新新的解决方案。

建议 Ctrl + F 在页面内搜索问题。

如果对问题有新的更好的解决方案欢迎留言。

 

Q:Image size must contain 3 elements.

图片大小必须包括3元素

A:图片不是标准的RGB3通道图片,把图片转化成标准的RGB格式或者删除有问题的图片。

如果目录下有很多图片,可以用下面的命令来检查哪些图片不是标准RGB格式:

from PIL import Image     
import os       
path = '/Users/lyz/Desktop/dataset/images/' #图片目录 
for file in os.listdir(path):      
     extension = file.split('.')[-1]
     if extension == 'jpg':
           fileLoc = path+file
           img = Image.open(fileLoc)
           if img.mode != 'RGB':
                 print(file+', '+img.mode)

 

 

 

Q: 提取frozen_inference_graph.pb 这一文件时,出现了数据类型错误

TypeError: x and y must have the same dtype, got tf.float32 != tf.int32

A:在 post_processing_builder.py 文件中,修改下面这个函数:

def _score_converter_fn_with_logit_scale(tf_score_converter_fn, logit_scale):
  """Create a function to scale logits then apply a Tensorflow function."""
  def score_converter_fn(logits):
    cr = logit_scale
    cr = tf.constant([[cr]],tf.float32)
    print(logit_scale)
    print(logits)
    scaled_logits = tf.divide(logits, cr, name='scale_logits') #change logit_scale
    return tf_score_converter_fn(scaled_logits, name='convert_scores')
  score_converter_fn.__name__ = '%s_with_logit_scale' % (
      tf_score_converter_fn.__name__)
  return score_converter_fn

 

 

 

 

 

Q:输出图片失真,偏蓝色

A:opencv的输出imwrite函数参数调整,输出标准RGB格式。

 

保存图片变蓝的请在最后一步的程序里在
plt.imshow(image_np)
前加上
image_np = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)

 

 

 

 

 

Q:能否识别中文的标签?

A:没有测试过,可以试试在训练文件里标注中文。

 

 

 

 

Q:训练过程中被打断,报错提示

Invalid GIF data, size 247023

A:GIF 文件过大,不支持格式。(小的GIF是否支持我不清楚,请测试过的读者告知)

 

 

 

 

Q:训练正常,但是效果很差。

A:

1、迭代次数不够,默认设置200000次,在Tensorboard中查看,Loss趋于稳定后差不多可以结束训练;

2、训练数据少,理论上讲,训练、测试数据数量越多、数据质量约好,效果越好;

  • 10
    点赞
  • 37
    收藏
    觉得还不错? 一键收藏
  • 59
    评论
基于TensorFlow Object Detection API搭建自己的物体识别模型的代码如下: 1. 准备工作: - 安装TensorFlow Object Detection API - 准备训练和测试数据集 - 下载预训练的模型权重 2. 导入所需库: ```python import tensorflow as tf from object_detection.utils import dataset_util from object_detection.utils import label_map_util from object_detection.utils import visualization_utils as vis_util ``` 3. 加载label map和模型: ```python PATH_TO_LABELS = 'path_to_label_map.pbtxt' PATH_TO_MODEL = 'path_to_pretrained_model' label_map = label_map_util.load_labelmap(PATH_TO_LABELS) categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=90, use_display_name=True) category_index = label_map_util.create_category_index(categories) detection_graph = tf.Graph() with detection_graph.as_default(): od_graph_def = tf.GraphDef() with tf.gfile.GFile(PATH_TO_MODEL, 'rb') as fid: serialized_graph = fid.read() od_graph_def.ParseFromString(serialized_graph) tf.import_graph_def(od_graph_def, name='') ``` 4. 定义函数进行物体识别: ```python def detect_objects(image): with detection_graph.as_default(): with tf.Session(graph=detection_graph) as sess: image_tensor = detection_graph.get_tensor_by_name('image_tensor:0') detection_boxes = detection_graph.get_tensor_by_name('detection_boxes:0') detection_scores = detection_graph.get_tensor_by_name('detection_scores:0') detection_classes = detection_graph.get_tensor_by_name('detection_classes:0') num_detections = detection_graph.get_tensor_by_name('num_detections:0') image_expanded = np.expand_dims(image, axis=0) (boxes, scores, classes, num) = sess.run( [detection_boxes, detection_scores, detection_classes, num_detections], feed_dict={image_tensor: image_expanded}) vis_util.visualize_boxes_and_labels_on_image_array( image, np.squeeze(boxes), np.squeeze(classes).astype(np.int32), np.squeeze(scores), category_index, use_normalized_coordinates=True, line_thickness=8) return image ``` 5. 加载测试图像并进行物体识别: ```python image = cv2.imread('test_image.jpg') image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) output_image = detect_objects(image) cv2.imshow('Object Detection', output_image) cv2.waitKey(0) cv2.destroyAllWindows() ``` 通过以上代码,可以使用自己的训练数据集、预训练模型权重和标签映射文件来搭建自己的物体识别模型。设置好路径并加载模型后,将待识别的图像传入`detect_objects`函数即可返回识别结果,并在图像上进行可视化展示。
评论 59
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值