tensorflow数据读取——TFRecord

三种办法

  • 预加载数据(Preload):在TensorFlow图中定义常量或变量来保存所有数据(仅适用于数据量比较小的情况)。
  • 供给数据(Feeding):在TensorFlow程序运行的每一步,让Python代码来供给数据。
  • 从文件读取数据:在TensorFlow图的起始,让一个输入管线从文件中读取数。(大数据常用方法)

从文件读取数据

  • 可以使用TFRecord文件(以二进制存储),但是这个要先自己制作。
  • 通过文件名列表产生文件名队列(monodepth就是通过一个存储了图片名字的txt文件结合tf.train.string_input_producer读取数据)。
TFRecord 文件
  • 通过tf.train.Example来写入TFRecord文件, tf.train.Example 包含了一个字典,key是字符串,value为Feature,Feature可以取值为字符串(BytesList )、浮点数列表(FloatList )、整型数列表(Int64List )。
  • TFRecord文件写入步骤:
    1、首先要获取我们需要转化的数据
    2、将数据填入到Example PB, 并且将Example PB 转化为一个字符串
    3、通过 tf.python_io.TFRecordWriter 将字符串写入TFRecord 文件
def imageTotfrecord(recordName,filePath):
    writer = tf.python_io.TFRecordWriter(recordName)
    for root, dir, files in os.walk(filePath):
        for filename in files:
            label = filename.split('_')[0]
            label = bytes(label, encoding = "utf8") #将字符串转为bytes
            img = Image.open(os.path.join(root, filename))
            img_raw = img.tobytes()
            example = tf.train.Example(features=tf.train.Features(feature={
            "label": tf.train.Feature(bytes_list=tf.train.BytesList(value=[label])),
            'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
            }))
            writer.write(example.SerializeToString())  #序列化为字符串
    writer.close()
  • 如何将字符串(label)写入TFRecord文件
    先用byte将string转为byte,再采用BytesList格式
# str to bytes
bytes(s, encoding = "utf8")
# bytes to str
str(b, encoding = "utf-8")
  • 读取TFRecord文件
def readTFRecord(recordName):
    reader = tf.TFRecordReader()
    filename_queue = tf.train.string_input_producer([recordName])
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(
        serialized_example,
        features={
            'img_raw': tf.FixedLenFeature([], tf.string),
            'label': tf.FixedLenFeature([], tf.string),
        })
    img = tf.decode_raw(features['img_raw'], tf.uint8) # 数据转换
    label = tf.cast(features['label'], tf.string) 
    sess = tf.Session()
    #启动多线程处理输入数据
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    #每次运行可以读取TFRecord文件中的一个样例。当所有样例都读完之后,再次样例中的程序会重头读取
    for i in range(2):
        _, mlabel = sess.run([img, label])
        mlabel= str(mlabel, encoding = "utf-8")
        print(mlabel)
参考资料
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值