首先先说明一下,目前我在做图片分类的时候,feed进入的图片数据是怎么做的?
1,有一个函数可以读取出所有图片的路径,并有一个label映射,这个函数本身还有打乱数据的功能
2,训练时,为了防止OOM,每次读取一部分图片进行训练。
这个过程倒是挺简单,但是问题在于,所有的数据都很分散,每次新增图片,或者删减图片都不清楚,并且如果要迁移到一台新的训练机器上数据量越大迁移时间越长。并且训练时候读取图片什么的耗时还是蛮长的。而TFRecord可以很好的解决这些问题,将所有的数据组成在一起,并且tf对其读速度很快,减少了维护的成本,其也会将样本大小减少很多。缺点就是,每次新增图片都要修改这个record文件。
首先TFRecord简介:
- 是一种数据格式,用于存储二进制记录,其底层采用的是protocol buffer【https://developers.google.cn/protocol-buffers/】
如何使用?怎么写,怎么读?
写:
# 为具体写入文件路径
# tf.python_io tf专门给python做的用来操作TFRecord类型的方法
1,writer = tf.python_io.TFRecordWriter(path)
# 采用tf官方用例 前面的都是类型转化函数
def _bytes_feature(value):
"""Returns a bytes_list from a string / byte."""
if isinstance(value, type(tf.constant(0))):
value = value.numpy() # BytesList won't unpack a string from an EagerTensor.
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
def _float_feature(value):
"""Returns a float_list from a float / double."""
return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))
def _int64_feature(value):
"""Returns an int64_list from a bool / enum / int / uint."""
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
# writer = tf.python_io.TFRecordWriter("test2.tfrecord")
def serialize_example(lable, image):
'''
Creates a tf.Example message ready to be written to a file.
'''
# Create a dictionary mapping the feature name to the tf.Example-compatible
# data type.
feature = {
'lable': _int64_feature(lable),
'image': _bytes_feature(image.tobytes())
}
# Create a Features message using tf.train.Example.
example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
return example_proto.SerializeToString()
file = "test1.jpg"
image = cv2.imread(file)
result = serialize_example(1,image)
3, writer.writer(result)
4, writer.close()
读:
reader = tf.TFRecordReader()
filename_queue = tf.train.string_input_producer(["test2.tfrecord"])
key, serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(serialized_example,
features={
"image": tf.FixedLenFeature([], tf.string),
"lable": tf.FixedLenFeature([], tf.int64),
})
image_raw = tf.decode_raw(features['image'], tf.uint8)
image_raw = tf.reshape(image_raw, [100, 100, 3])
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
im = sess.run(image_raw)
coord.request_stop()
coord.join(threads)
print(im)