tfrecord的生产
针对的数据集:SynthText 合成数据集(约 41G)
1 准备生成 tfrecord 所需的变量
def _processing_image(wordbb, imname, coder):
    """Read one image and convert its word boxes to normalized bboxes.

    Args:
        wordbb: Array of word-box corner coordinates, indexed here as
            ``[x-or-y, corner, box]``. A 2-D array is treated as a single
            box. NOTE(review): presumably the SynthText ``wordBB`` entry of
            shape (2, 4, N) -- confirm against the caller.
        imname: Image file name; it is appended to
            ``FLAGS.datasets_jpgfile`` by plain concatenation.
        coder: Object whose ``decode_jpeg(image_data)`` returns the decoded
            image array. Only the array's ``shape`` is used. Per the
            original note it runs ``tf.image.decode_jpeg`` in a session --
            TODO confirm.

    Returns:
        Tuple ``(image_data, shape, bbox, label, imname)``:
        ``image_data`` is the raw encoded file content, ``shape`` is
        ``[height, width, channels]`` as a list, ``bbox`` is a list of
        ``[ymin, xmin, ymax, xmax]`` rows divided by the image size, and
        ``label`` holds one ``1`` per box.
    """
    # JPEG content is binary: 'rb' keeps the bytes intact on Python 3 (text
    # mode would try to decode them), and the with-block closes the handle.
    with tf.gfile.GFile(FLAGS.datasets_jpgfile + imname, 'rb') as image_file:
        image_data = image_file.read()

    image = coder.decode_jpeg(image_data)
    shape = list(image.shape)  # [h, w, c]

    # Give a lone 2-D box a trailing box axis so that one code path serves
    # zero, one or many boxes and every coordinate comes out as a scalar.
    wordbb = np.asarray(wordbb)
    if wordbb.ndim < 3:
        wordbb = wordbb[:, :, np.newaxis]
    numofbox = wordbb.shape[2]

    # Reduce over the corner axis: row 0 holds x values, row 1 holds y.
    xmin, ymin = np.min(wordbb, 1)
    xmax, ymax = np.max(wordbb, 1)

    # Scale by width/height; minima are floored at 0 and maxima capped at 1.
    xmin = np.maximum(xmin * 1. / shape[1], 0.0)
    ymin = np.maximum(ymin * 1. / shape[0], 0.0)
    xmax = np.minimum(xmax * 1. / shape[1], 1.0)
    ymax = np.minimum(ymax * 1. / shape[0], 1.0)

    bbox = [[ymin[i], xmin[i], ymax[i], xmax[i]] for i in range(numofbox)]
    label = [1] * numofbox  # one label "1" per bounding box
    return image_data, shape, bbox, label, imname
这里所需的变量有 image_data,shape,bbox,label,imname
2 生成写入 .tfrecord 所需的 tf.train.Example 对象
def _convert_to_example(image_data, shape, bbox, label, imname):
    """Pack one image and its boxes into a ``tf.train.Example``.

    Args:
        image_data: Encoded image content stored under ``image/encoded``.
        shape: Sequence indexed as ``[height, width, channels]``.
        bbox: List of ``[ymin, xmin, ymax, xmax]`` rows; may be empty.
        label: List of integer labels stored under
            ``image/object/bbox/label``.
        imname: Image name; its raw bytes are stored under ``image/name``.
            NOTE(review): must expose ``tobytes()`` (e.g. a NumPy string
            scalar) -- confirm against the caller.

    Returns:
        The populated ``tf.train.Example``.
    """
    nbbox = np.array(bbox)
    if nbbox.size == 0:
        # An empty list becomes a 1-D array, which the column slicing below
        # cannot index; give it an explicit (0, 4) shape instead.
        nbbox = nbbox.reshape(0, 4)
    ymin = list(nbbox[:, 0])
    xmin = list(nbbox[:, 1])
    ymax = list(nbbox[:, 2])
    xmax = list(nbbox[:, 3])

    example = tf.train.Example(features=tf.train.Features(feature={
        'image/height': int64_feature(shape[0]),
        'image/width': int64_feature(shape[1]),
        'image/channels': int64_feature(shape[2]),
        'image/shape': int64_feature(shape),
        'image/object/bbox/ymin': float_feature(ymin),
        'image/object/bbox/xmin': float_feature(xmin),
        'image/object/bbox/ymax': float_feature(ymax),
        'image/object/bbox/xmax': float_feature(xmax),
        'image/object/bbox/label': int64_feature(label),
        'image/format': bytes_feature('jpeg'),
        'image/encoded': bytes_feature(image_data),
        # tobytes() returns the same bytes as the deprecated tostring().
        'image/name': bytes_feature(imname.tobytes()),
    }))
    return example
3 写入tfrecord
# Name the output file after the 1-based index, e.g. "1.tfrecord".
tf_filename = str(i + 1) + '.tfrecord'
# Use the writer as a context manager so the file is flushed and closed even
# if serialization fails; an unclosed writer may leave records unwritten.
# Any further examples meant for this same file must be written inside the
# with-block.
with tf.python_io.TFRecordWriter(FLAGS.path_save + tf_filename) as tfrecord_writer:
    tfrecord_writer.write(example.SerializeToString())
4使用tfrecord
import tensorflow as tf
import os
# Short alias for the TF-Slim package used throughout this module.
slim = tf.contrib.slim
# Free-text description for each item name; get_datasets passes this dict as
# `items_to_descriptions` when building the slim Dataset.
ITEMS_TO_DESCRIPTIONS = {
'image': 'slim.tfexample_decoder.Image',
'shape': 'shape',
'height': 'height',
'width': 'width',
'object/bbox': 'box',
'object/label': 'label'
}
# Number of samples per split; get_datasets uses the 'train' entry as
# `num_samples`.
SPLITS_TO_SIZES = {
'train': 858750,
}
# Passed to the Dataset as `num_classes`.
# NOTE(review): presumably background + text -- confirm.
NUM_CLASSES = 2
def get_datasets(data_dir, file_pattern='*.tfrecord'):
    """Describe the TFRecord files under ``data_dir`` as a slim ``Dataset``.

    Args:
        data_dir: Directory that contains the record files.
        file_pattern: File name or glob joined onto ``data_dir``; defaults
            to ``'*.tfrecord'``.

    Returns:
        A ``slim.dataset.Dataset`` that reads with ``tf.TFRecordReader`` and
        decodes the items ``image``, ``shape``, ``height``, ``width``,
        ``object/bbox`` and ``object/label``. ``num_samples`` is taken from
        ``SPLITS_TO_SIZES['train']``.
    """
    file_patterns = os.path.join(data_dir, file_pattern)
    # Call form of print: same output on Python 2, and valid on Python 3
    # where the print statement is a SyntaxError.
    print('file_path: {}'.format(file_patterns))
    reader = tf.TFRecordReader

    # How each serialized feature is parsed. Box coordinates and labels are
    # variable-length because the number of boxes differs per image.
    keys_to_features = {
        'image/height': tf.FixedLenFeature([1], tf.int64),
        'image/width': tf.FixedLenFeature([1], tf.int64),
        'image/channels': tf.FixedLenFeature([1], tf.int64),
        'image/shape': tf.FixedLenFeature([3], tf.int64),
        'image/object/bbox/ymin': tf.VarLenFeature(dtype=tf.float32),
        'image/object/bbox/xmin': tf.VarLenFeature(dtype=tf.float32),
        'image/object/bbox/ymax': tf.VarLenFeature(dtype=tf.float32),
        'image/object/bbox/xmax': tf.VarLenFeature(dtype=tf.float32),
        'image/object/bbox/label': tf.VarLenFeature(dtype=tf.int64),
        'image/format': tf.FixedLenFeature([], tf.string, default_value='jpeg'),
        'image/encoded': tf.FixedLenFeature([], tf.string, default_value=''),
        'image/name': tf.VarLenFeature(dtype=tf.string),
    }

    # How parsed features are exposed as items. 'image/name' and
    # 'image/channels' are parsed above but have no handler here.
    items_to_handlers = {
        'image': slim.tfexample_decoder.Image('image/encoded', 'image/format'),
        'shape': slim.tfexample_decoder.Tensor('image/shape'),
        'height': slim.tfexample_decoder.Tensor('image/height'),
        'width': slim.tfexample_decoder.Tensor('image/width'),
        'object/bbox': slim.tfexample_decoder.BoundingBox(
            ['ymin', 'xmin', 'ymax', 'xmax'], 'image/object/bbox/'),
        'object/label': slim.tfexample_decoder.Tensor('image/object/bbox/label'),
    }

    decoder = slim.tfexample_decoder.TFExampleDecoder(
        keys_to_features, items_to_handlers)

    labels_to_names = None  # no label-id to name mapping is supplied

    return slim.dataset.Dataset(
        data_sources=file_patterns,
        reader=reader,
        decoder=decoder,
        num_samples=SPLITS_TO_SIZES['train'],
        items_to_descriptions=ITEMS_TO_DESCRIPTIONS,
        num_classes=NUM_CLASSES,
        labels_to_names=labels_to_names)
再通过设定相应的参数,就可以得到自己想要的数据
# Build the dataset description, then read it through slim's data provider.
dataset = sythtextprovider.get_datasets(dataset_dir, file_pattern=file_pattern)

# `shuffe` (sic) is a name supplied by the surrounding code.
provider = slim.dataset_data_provider.DatasetDataProvider(
    dataset,
    num_readers=num_readers,
    common_queue_capacity=512 * 16 + 20 * batch_size,
    common_queue_min=512 * 16,
    shuffle=shuffe,
)

# Values come back in the same order as the requested item names.
image, shape, glabels, gbboxes, height, width = provider.get(
    ['image', 'shape', 'object/label', 'object/bbox', 'height', 'width'])