Tensorflow中的文件读取流程

'''
本文主要介绍文件读取的主要流程
使用 多线程 + 队列 的方式

文件读取流程:

(1) 构造文件队列
    file_queue = tf.train.string_input_producer(string_tensor, shuffle=True)
    sting_tensor 是文件名+路径
    shuffle = True 打乱文件顺序

(2) 读取与解码
    tf.WholeFileReader 用于读取图片实例
    tf.TextLineReader 用于读取文本文件csv格式,默认按行读取
    tf.FixedLengthRecordReader 用于读取二进制文件
    tf.TFRecordReader 用于读取TFRecords文件
    
(3) 批处理队列
    tf.train.batch(image, batch_size, num_threads, capacity)
'''

# -------------------------------------------------------分割线-------------------------------------------------------

import tensorflow.compat.v1 as tf
import os

tf.disable_eager_execution()
os.environ['TF_CPP_MIN_LOGLEVEL']='2'

def file_read(all_file):

    # 1.构造文件队列
    filequeue = tf.train.string_input_producer(all_file)

    # 2. 读取与解码
    reader = tf.WholeFileReader()  # Reader是一个读取器实例
    key, value = reader.read(filequeue) # read是一个tensors元组,其中key是文件名, value是一个样本的原始编码形式

    # 解码
    image = tf.image.decode_jpeg(value) # 此时的image是三维数组

    '''
    图像形状和类型的修改
    如果要对图像进行批处理,就需要统一图片的类型和格式
    '''
    image_resize = tf.image.resize_images(image, [200, 200]) # 这里的200,200是长宽
    image_resize.set_shape(shape=[200, 200, 3]) # 200,200,3是长,宽,通道数

    # 3. 批处理
    image_batch = tf.train.batch([image_resize], batch_size=100, num_threads=1, capacity=100)

    # 开启会话
    with tf.Session() as sess:
        # 开启线程
        # 线程协调员
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        key, value, image, image_resize = sess.run([key, value, image, image_resize])  # 如果没有开启线程,这一步会报错
        print('key:\n', key)
        print('value:\n', value)
        print('image:\n', image)
        print('image_resize:\n', image_resize)

        # 回收线程
        coord.request_stop()
        coord.join(threads)

# -------------------------------------------------------分割线-------------------------------------------------------

if __name__ == '__main__':

    # 构造路径+文件名的列表
    dirname = './test_picture'
    all_file = []
    filelist = os.listdir(dirname) # 以列表的形式返回dirname目录下的所有文件

    for file in filelist:
        if file.endswith('.JPG'):   # 用了endswith,说明file是一个字符串
            filename = os.path.join(dirname, file)
            all_file.append(filename)

    print(all_file)

    file_read(all_file)
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值