TensorFlow的数据读取机制

1. tf.train.slice_input_producer()

原文链接:
https://blog.csdn.net/qq_30666517/article/details/79715045
https://blog.csdn.net/dcrmg/article/details/79776876
tf.train.slice_input_producer([image,label],num_epochs=10)
现已被tf.data.slice_input_producer替代。
随机产生一个图片和标签,num_epochs=10,则表示把所有的数据过10遍,使用完所有的图片数据为一个epoch,这是重复使用10次。上面的用法表示你的数据集和标签已经全部加载到内存中了,如果数据集非常庞大,我们通过这个函数也可以只加载图片的路径,放入图片的path,注意path必须是一个list或者tensorlist。见下面代码实例:

import tensorflow as tf
import glob
import matplotlib.pyplot as plt
import time

datapath=r'path/to/dataset/'
imgpath = glob.glob(datapath+'*.bmp')
# 将路径转化成张量形式
imgpath = tf.convert_to_tensor(imgpath)

# 产生一个队列每次随机产生一张图片地址
# 注意这里要放在数组里面
image = tf.train.slice_input_producer([imgpath])
# 得到一个batch的图片地址
# 由于tf.train.slice_input_producer()函数默认是随机产生一个实例
# 所以在这里直接使用tf.train.batch()直接获得一个batch的数据即可
# 没有必要再去使用tf.trian.shuffle_batch() 速度会慢
img_batch = tf.train.batch([image],batch_size=20,capacity=100)
 
with tf.Session() as sess:
    coord = tf.train.Coordinator()
    thread = tf.train.start_queue_runners(sess,coord)
    i = 0
    try:
        while not coord.should_stop():
            imgs = sess.run(img_batch)
            print(imgs)
            fig = plt.figure()
            for i,path in enumerate(imgs):
                img = plt.imread(path[0].decode('utf-8'))
                axes = fig.add_subplot(5,4,i+1)
                axes.imshow(img)
                axes.axis('off')
            plt.ion()
            plt.show()
            time.sleep(1)
            plt.close()
            i+=1
            if i%10==0:
                break
    except tf.errors.OutOfRangeError:
        pass
    finally:
        coord.request_stop()
    coord.join(thread)

注意路径此时被加载成二进制编码格式了。

2. tf.train.slice_input_producer()

tf.train.slice_input_producer([path])
批量读取图片,得到每个图片的路径后,我们可以加载图片并解码成三维数组的形式(图像的深度必须是3通道或者4通道,笔者实验灰度图像,一直不成功)。当使用tf.train.slice_input_producer()时,加载图片数据的reader使用tf.read_file(filename),直接读取。注意图片记得resize()。见下面代码:

# 用于通过读取图片的path,然后解码成图片数组的形式,最后返回batch个图片数组
import glob
import tensorflow as tf
import matplotlib.pyplot as plt
 
path_list = r'path/to/dataset/'
img_path = glob.glob(path_list+'*.bmp')
img_path = tf.convert_to_tensor(img_path,dtype=tf.string)
 
# 这里img_path,不放在数组里面
# num_epochs = 1,表示将文件下所有的图片都使用一次
# num_epochs和tf.train.slice_input_producer()中是一样的
# 此参数可以用来设置训练的 epochs
image = tf.train.slice_input_producer([img_path],num_epochs=1)
 
 
# load one image and decode img
def load_img(path_queue):
# 创建一个队列读取器,然后解码成数组
#    reader = tf.WholeFileReader()
#    key,value = reader.read(path_queue)
    file_contents = tf.read_file(path_queue[0])
    img = tf.image.decode_bmp(file_contents,channels=1)
	# 这里很有必要,否则会出错
	# 感觉这个地方貌似只能解码3通道以上的图片
    img = tf.image.resize_images(img,size=(100,100))
    # img = tf.reshape(img,shape=(50,50,4))
    return img
img = load_img(image)
print(img.shape)
image_batch = tf.train.batch([img],batch_size=20)
 
with tf.Session() as sess:
    
    # initializer for num_epochs
    tf.local_variables_initializer().run()
    coord = tf.train.Coordinator()
    thread = tf.train.start_queue_runners(sess=sess,coord=coord)
    try:
        while not coord.should_stop():
            imgs = sess.run(image_batch)
            print(imgs.shape)
    except tf.errors.OutOfRangeError:
        print('done')
    finally:
        coord.request_stop()
    coord.join(thread)

3. tf.train.string_input_producer()

tf.train.string_input_producer(path)
传入路径时,不需要放入list中。然后加载图片的reader是tf.WholeFileReader(),其他地方和tf.train.slice_input_producer()函数用法基本类似。见代码:

# 用于通过读取图片的path,然后解码成图片数组的形式,最后返回batch个图片数组
import glob
import tensorflow as tf
 
path_list = r'/media/wsw/文档/pythonfile_withpycharm/SVMLearning/faceLibrary/人脸库/Yale2/'
img_path = glob.glob(path_list+'*.bmp')
img_path = tf.convert_to_tensor(img_path,dtype=tf.string)
 
# 这里img_path,不放在数组里面
# num_epochs = 1,表示将文件下所有的图片都使用一次
# num_epochs和tf.train.slice_input_producer()中是一样的
# 此参数可以用来设置训练的 epochs
image = tf.train.string_input_producer(img_path,num_epochs=1)
 
 
# load one image and decode img
def load_img(path_queue):
    # 创建一个队列读取器,然后解码成数组
    reader = tf.WholeFileReader()
    key,value = reader.read(path_queue)
    img = tf.image.decode_bmp(value,channels=3)
	# 这里很有必要,否则会出错
	# 感觉这个地方貌似只能解码3通道以上的图片
    # img = tf.image.resize_images(img,size=(100,100))
    img = tf.reshape(img,shape=(224,224,3))
    return img
   
img = load_img(image)
print(img.shape)
image_batch = tf.train.batch([img],batch_size=20)
 
with tf.Session() as sess:
    
    # initializer for num_epochs
    tf.local_variables_initializer().run()
    coord = tf.train.Coordinator()
    thread = tf.train.start_queue_runners(sess=sess,coord=coord)
    try:
        while not coord.should_stop():
            imgs = sess.run(image_batch)
            print(imgs.shape)
    except tf.errors.OutOfRangeError:
        print('done')
    finally:
        coord.request_stop()
    coord.join(thread)

4. tf.data.TFRecordDataset

  • 制作17flowers数据集。为减少内存消耗,这里仅将./17flowers/class_name/image_path,即图片文件的路径作为文本写入tfrecord文件。同时写入class_text。
  • 使用tf.data.TFRecordDataset创建迭代器进行读取tfrecord中的图片路径,然后进行解析。
    import tensorflow as tf 
    import numpy as np 
    import os
    from datetime import datetime
    
    dataset_dir = r"E:\MyCollectionFinished\my_vgg_tensorflow\dataset\17flowers"
    class_names = os.listdir(dataset_dir)
    tfrecord_path = os.path.join(dataset_dir, "image_paths.tfrecord")
    writer = tf.python_io.TFRecordWriter(tfrecord_path)
    for class_name in class_names:
        class_path = os.path.join(dataset_dir, class_name)
        if not os.path.isdir(class_path):
            continue
        print(class_path)
        image_paths = os.listdir(class_path)
        for idx, image_path in enumerate(image_paths):
            image_path = os.path.join(class_path, image_path)
            path = tf.train.BytesList(value=[bytes(image_path, encoding="utf-8")])
            cls_text = tf.train.BytesList(value=[bytes(class_name, encoding="utf-8")])
            feature_dict = {"image_path": tf.train.Feature(bytes_list=path),
                           "class_text": tf.train.Feature(bytes_list=cls_text)}
            features = tf.train.Features(feature=feature_dict)
            example = tf.train.Example(features=features)
            writer.write(example.SerializeToString())
    writer.close()
    
    将产生如下输出:
    E:\MyCollectionFinished\my_vgg_tensorflow\dataset\17flowers\0
    E:\MyCollectionFinished\my_vgg_tensorflow\dataset\17flowers\1
    ...
    
  • 进行tfrecord文件读取
    def parse_exmaple(serialized_example):
        features = {"image_path": tf.FixedLenFeature([], tf.string),
                        "class_text": tf.FixedLenFeature([], tf.string)}
        features = tf.parse_single_example(serialized_example, features=features)
        # path = features["image_path"]
        # cls_text = features["class_text"]
        return features
    
  • 创建迭代器,分别输出特征
    tfrecord_path = r"E:\MyCollectionFinished\my_vgg_tensorflow\dataset\17flowers\image_paths.tfrecord"
    dataset = tf.data.TFRecordDataset(tfrecord_path)
    dataset = dataset.map(parse_exmaple)
    iterator = dataset.make_one_shot_iterator()
    with tf.Session() as sess:
        iterator.get_next()
        for i in range(1000):
            if i%100:
                continue
            feature = sess.run(feature_tensor)
            print(feature['image_path'].decode('utf-8'))
    

    注意:feature['image_path'].decode('utf-8'),这里如果不加.decode('utf-8'),输出将不正确。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值