TFRecord数据集

tensorflow标准的读取数据格式:TFRecord

       可以允许你讲任意的数据转换为TensorFlow所支持的格式, 这种方法可以使TensorFlow的数据集更容易与网络应用架构相匹配。这种建议的方法就是使用TFRecords文件,TFRecords文件包含了tf.train.Example 协议内存块(protocol buffer)(协议内存块包含了字段Features)。你可以写一段代码获取你的数据, 将数据填入到Example协议内存块(protocol buffer),将协议内存块序列化为一个字符串, 并且通过tf.python_io.TFRecordWriter class写入到TFRecords文件。


1.生成TFRecord文件

(1)获取图片数据

(2)填入Example。

example = tf.train.Example(features = tf.train.Features(feature = {
            "label":_int64_feature(index),
            "img_raw":_bytes_feature(img_raw)
        }))
(3)要写入到文件中,先定义writer:
writer = tf.python_io.TFRecordWriter("./xx.tfrecords 要生成的tfrecord文件名")
将协议内存块序列化为一个字符串, 并且通过tf.python_io.TFRecordWriter写入到TFRecords文件。
writer.write(example.SerializeToString()) #协议内存块转换为字符串,再用writer写入TFRecord中

2.读取TFRecord文件


tensorflow读取数据时,将文件名列表交给tf.train.string_input_producer 函数.string_input_producer来生成一个先入先出的队列, 文件阅读器会需要它来读取数据。

string_input_producer 提供的可配置参数来设置文件名乱序和最大的训练迭代数,

QueueRunner会为每次迭代(epoch)将所有的文件名加入文件名队列中,

如果shuffle=True的话, 会对文件名进行乱序处理。这一过程是比较均匀的,因此它可以产生均衡的文件名队列。

即 将文件列表交给该函数,该函数生成队列,文件阅读器从队列中读数据。根据文件的格式选择阅读器的种类,每个阅读器都有对应的read方法,

read(文件名队列),返回字符串标量,返回的量可以被解析器解析,变成张量。


从TFRecords文件中读取数据, 可以使用tf.TFRecordReadertf.parse_single_example解析器。
这个parse_single_example操作可以将Example协议内存块(protocol buffer)解析为张量


(1)定义队列
filename_queue = tf.train.string_input_producer(["catVSdog_train.tfrecords"])

(2)读数据,定义阅读器,使用相应的read方法。
reader = tf.TFRecordReader()
_,serialized_example = reader.read(filename_queue)
(3)解析返回的量
im_features = tf.parse_single_example(serialized_example,
                                   features={
                                       'label': tf.FixedLenFeature([], tf.int64),
                                       'img_raw' : tf.FixedLenFeature([], tf.string),
                                   })
(4)获取数据
image = tf.decode_raw(im_features['img_raw'],tf.uint8)
image = tf.reshape(image,[128,128,3])
label = tf.cast(im_features['label'],tf.int32)

取batch,乱序(shufflle_batch)和不乱序(.batch),这样下面image,label= sess.run(image_batch,label_batch)
img_batch, label_batch = tf.train.shuffle_batch([img, label],
                                                batch_size=30, capacity=2000,
                                                min_after_dequeue=1000)
(5)创建线程并使用QueueRunner对象来预取 的模板
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)
   #操作
    coord.request_stop()
    coord.join(threads)

with tf.Session() as sess:
    init_op = tf.initialize_all_variables()
    sess.run(init_op)
    #模板
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)
    #
    for i in range(20):
        example, l = sess.run([image, label])
        img = Image.fromarray(example, 'RGB')
        img.save(image_path +"\\"+ str(i) + '_''Label_' + str(l) + '.jpg')
        print(example, l)
    #模板
    coord.request_stop()
    coord.join(threads)


参考:
 (1)tensorflow数据读取:http://wiki.jikexueyuan.com/project/tensorflow-zh/how_tos/reading_data.html
(2)数据集制作:
http://www.cnblogs.com/upright/p/6136265.html
https://www.2cto.com/kf/201702/604326.html
https://github.com/tensorflow/tensorflow/blob/r1.3/tensorflow/examples/how_tos/reading_data/convert_to_records.py
https://github.com/kevin28520/My-TensorFlow-tutorials/blob/master/03%20TFRecord/notMNIST_input.py















  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值