TensorFlow API 训练 Mask R-CNN

# 1 用 labelme 标注数据集

注意如果图片中间有空洞,标签应该为hole.

2 将数据集转为 TFRecord(.record)格式

需要准备三个代码文件

create_tf_record.py

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sun Aug 26 10:57:09 2018

@author: shirhe-lyh
"""

"""Convert raw dataset to TFRecord for object_detection.

Please note that this tool only applies to labelme's annotations(json file).

Example usage:
    python3 create_tf_record.py \
        --images_dir=your absolute path to read images. \
        --annotations_json_dir=your path to annotation json files. \
        --label_map_path=your path to label_map.pbtxt \
        --output_path=your path to write .record.
"""

import cv2
import glob
import hashlib
import io
import json
import numpy as np
import os
import PIL.Image
import tensorflow as tf

import read_pbtxt_file

# Command-line flags for the converter. The defaults are hard-coded absolute
# paths; pass --<flag_name>=<value> on the command line to override them.
flags = tf.app.flags

flags.DEFINE_string(
    'images_dir',
    r'/media/kou/7a9ede55-bc3d-4da2-b7bd-3fda6e9bf2f2/kou/mask_rcnn_data/mask_rcnn_complete_data/paper_box',
    'Path to images directory.')
flags.DEFINE_string(
    'annotations_json_dir',
    r'/media/kou/7a9ede55-bc3d-4da2-b7bd-3fda6e9bf2f2/kou/mask_rcnn_data/mask_rcnn_complete_data/paper_box_anno',
    'Path to annotations directory.')
flags.DEFINE_string(
    'label_map_path',
    r'/media/kou/7a9ede55-bc3d-4da2-b7bd-3fda6e9bf2f2/kou/mask_rcnn_data/rubbish_label_map.pbtxt',
    'Path to label map proto.')
flags.DEFINE_string(
    'output_path',
    r'/media/kou/7a9ede55-bc3d-4da2-b7bd-3fda6e9bf2f2/kou/mask_rcnn_data/paper_box.record',
    'Path to the output tfrecord.')

FLAGS = flags.FLAGS


def int64_feature(value):
    """Wrap a single integer in a `tf.train.Feature`.

    Args:
        value: The integer to store; it is placed in a one-element list.

    Returns:
        A `tf.train.Feature` whose `int64_list` holds `[value]`.
    """
    wrapped = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=wrapped)


def int64_list_feature(value):
    """Wrap a sequence of integers in a `tf.train.Feature`.

    Unlike `int64_feature`, `value` is passed through as-is rather than being
    wrapped in a list, so it must already be an iterable of integers.

    Args:
        value: An iterable of integers.

    Returns:
        A `tf.train.Feature` whose `int64_list` holds the given values.
    """
    # The former debug `print(value)` was removed: it wrote every list to
    # stdout on each call, and none of the sibling helpers print anything.
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))


def bytes_feature(value):
    """Wrap a single bytes value in a `tf.train.Feature`.

    Args:
        value: The bytes object to store; it is placed in a one-element list.

    Returns:
        A `tf.train.Feature` whose `bytes_list` holds `[value]`.
    """
    wrapped = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=wrapped)


def bytes_list_feature(value):
    """Wrap a sequence of bytes values in a `tf.train.Feature`.

    Args:
        value: An iterable of bytes objects, passed through unchanged.

    Returns:
        A `tf.train.Feature` whose `bytes_list` holds the given values.
    """
    wrapped = tf.train.BytesList(value=value)
    return tf.train.Feature(bytes_list=wrapped)


def float_list_feature(value):
    """Wrap a sequence of floats in a `tf.train.Feature`.

    Args:
        value: An iterable of floats, passed through unchanged.

    Returns:
        A `tf.train.Feature` whose `float_list` holds the given values.
    """
    wrapped = tf.train.FloatList(value=value)
    return tf.train.Feature(float_list=wrapped)


def create_tf_example(annotation_dict, label_map_dict=None):
    """Converts image and annotations to a tf.Example proto.

    Args:
        annotation_dict: A dictionary containing the following keys:
            ['height', 'width', 'filename', 'sha256_key', 'encoded_jpg',
             'format', 'xmins', 'xmaxs', 'ymins', 'ymaxs', 'masks',
             'class_names'].
        label_map_dict: A dictionary maping class_names to indices.

    Returns:
        example: The converted tf.Example.

    Raises:
        ValueError: If label_map_dict is None or is not containing a class_name.
    """
    if annotation_dict is None:
        return None
    if label_map_dict is None:
        raise ValueError('`label_map_dict` is None')

    height = annotation_dict.get('height', None)
    width = annotation_dict.get('width', None)
    filename = annotation_dict.get('filename', None)
    sha256_key = annotation_dict.get('sha256_key', None)
    encoded_jpg = annotation_dict.get('encoded_jpg', None)
    image_format = annotation_dict.get('format', None)
    xmins = annotation_dict.get('xmins', None)
    xmaxs = annotation_dict.get('xmaxs', None)
    ymins = annotation_dict.get('ymins', None)
    ymaxs = annotation_dict.get('ymaxs', None)
    masks = annotation_dict.get('masks', None)
    # print(masks)
    class_names = annotation_dict.get('class_names', None)

    labels = []
    for class_name in class_names:
        label = label_map_dict.get(class_name, None)
        if label is None:
            raise ValueError(
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值