# Convert CSV-format annotation data into a TFRecord file for use with TensorFlow.
# Intended to be run under TensorFlow 2.
from __future__ import division
from __future__ import print_function
from __future__ import absolute_import
import os
import io
import pandas as pd
import tensorflow.compat.v1 as tf
from PIL import Image
from object_detection.utils import dataset_util
from collections import namedtuple, OrderedDict
# Command-line flag definitions, using the TF1-style flags API reached through
# `tensorflow.compat.v1` (imported above as `tf`).
flags = tf.app.flags
# --csv_input: path to the CSV input; defaults to an empty string.
flags.DEFINE_string('csv_input', '', 'Path to the CSV input')
# --output_path: path to the output TFRecord; defaults to an empty string.
flags.DEFINE_string('output_path', '', 'Path to output TFRecord')
# Module-level handle through which the flag values are read.
FLAGS = flags.FLAGS
def class_text_to_int(row_label):
    """Map a class label to its integer class id.

    Args:
        row_label: Class label to look up; compared against the known
            label names with ``==``.

    Returns:
        1 when ``row_label`` equals ``'Citrus'``, otherwise 0.
    """
    # Only 'Citrus' is currently enabled. Additional labels were previously
    # sketched as disabled branches ('Apple' -> 2, 'Orange' -> 3,
    # 'Watermelon' -> 4); add them as further branches here if needed.
    # NOTE(review): every unrecognised label maps to 0 -- confirm this is
    # intended, since id 0 is conventionally reserved for background in
    # detection label maps.
    if row_label == 'Citrus':
        return 1
    return 0
def split(df, group):
data = namedtuple('data', ['filename', 'object'])
gb = df.groupby(group)
return [data(filename, gb.