TensorFlow学习笔记之读取数据概述

概述

TensorFlow读取数据,官网介绍的方法有3种:

  • 预加载数据 (Preloaded data): 在Graph中定义常量或变量来保存数据。
  • 供给数据 (Feeding): 在Graph运行中将Python代码产生好的数据供给TF后端。 
  • 从文件读取数据 (Reading from file): 在Graph的起始, 利用输入管线直接从文件中读取数据(最常用)。

看官网上这么写,还是不太清楚这几种方法到底是怎么实现的,于是查了些资料,稍作整理一番。

先理解一下TensorFlow的工作模式:

TF底层也就是计算核心模块和运行框架是用C++写的,同时提供API给Python (TF也提供了C++、Java、Go的API, 没用过, 不管),然后,Python调用这些API设计网络模型Graph,交给后端运算执行。所以Python负责Design,C++负责Run。


预加载数据 

仅适用于可以完全加载到内存中的小数据集。有两种方法:

  1. 存储在常量中。
  2. 存储在变量中,且初始化后值不变。

数据存到常量中:

import tensorflow as tf  
x_data = [2, 3, 4]
y_data = [4, 0, 1]
x = tf.constant(x_data)  
y = tf.constant(y_data) 
with tf.Session() as sess:
    ...
    sess.run(x)
    sess.run(y)

数据存到变量中,就需要在数据流图建立后初始化这个变量,而且值不能再被改变:(这里也用到了占位符,但重点是使用了变量存储数据)

import tensorflow as tf  
x_data = [2, 3, 4]
y_data = [4, 0, 1]
x_initializer = tf.placeholder(dtype=x.dtype,shape=x.shape)
y_initializer = tf.placeholder(dtype=y.dtype,shape=y.shape)
x = tf.Variable(x_initializer,trainable=False,collections=[])
y = tf.Variable(y_initializer,trainable=False,collections=[])
with tf.Session() as sess:
    ...
    sess.run(x.initializer, feed_dict={x_initializer: x_data})
    sess.run(y.initializer, feed_dict={y_initializer: y_data})

设置 trainable=False 可以防止该变量被数据流图的 GraphKeys.TRAINABLE_VARIABLES 收集, 这样在训练的时候变量就不会和其他网络参数一样被更新; 设置 collections=[] 可以防止被 GraphKeys.VARIABLES 收集做为保存和恢复的中断点。


供给数据 

TensorFlow有数据供给机制,允许在Graph中将数据注入到任一张量。python代码产生的数据可以通过此方式直接输入到Graph。

设计placeholder节点的唯一意图就是为了提供数据供给(feeding)的方法。placeholder节点被声明的时候是未初始化的,不包含数据,需要通过run()或者eval()函数输入feed_dict 参数, 才能启动运算。

import tensorflow as tf  
x1 = tf.placeholder(tf.int16)  
x2 = tf.placeholder(tf.int16)  
y = tf.add(x1, x2)
#python产生数据 
data1 = [2, 3, 4]  
data2 = [4, 0, 1]  
with tf.Session() as sess:  
    sess.run(y, feed_dict={x1:data1, x2:data2})
#或者用eval()
#with tf.Session():
#    y.eval(feed_dict={x1:data1, x2:data2})


从文件读取数据

根据文件格式, 选择对应的文件阅读器, 然后将文件名队列提供给阅读器的read方法。read输出的key表征输入的文件和纪录,而字符串标量value可以被不同的解析器解码成张量样本。它就是我们读到的数据

从csv文件读取数据 

从CSV文件中读取数据, 需要使用 TextLineReader和 decode_csv,使用一个reader的写法如下:

# -*- coding:utf-8 -*-  

import tensorflow as tf

#生成文件名队列
filenames = ['num1.csv', 'num2.csv']  
filename_queue = tf.train.string_input_producer(filenames, shuffle=True)  
#定义阅读器
reader = tf.TextLineReader()  
key, value = reader.read(filename_queue)  
#定义解码器,一次读一行
example, label = tf.decode_csv(value, record_defaults=[[1], [1]])
#使用tf.train.batch()相当于多加了一个样本队列和一个QueueRunner
#example, label = tf.train.batch([example,label],batch_size=3) 
#example, label = tf.train.shuffle_batch([example,label],batch_size=3,capacity=100,min_after_dequeue=10)

with tf.Session() as sess:  
    coord = tf.train.Coordinator()  #创建一个协调器,管理线程  
    threads = tf.train.start_queue_runners(coord=coord)
    for i in range(10):
        example_batch, label_batch = sess.run([example, label])
        print(example_batch, label_batch)
    coord.request_stop() 
    coord.join(threads)

上面的写法是一个reader,一个样本。

如果加上tf.train.batch可以实现一个reader,batch_size个样本。

如果用tf.train.shuffle_batch的话,也可以读batch_size个样本,并且打乱顺序。

使用多个reader的写法如下:

# -*- coding:utf-8 -*-

import tensorflow as tf

filenames = ['num1.csv', 'num2.csv']  
filename_queue = tf.train.string_input_producer(filenames, shuffle=True)  
reader = tf.TextLineReader()  
key, value = reader.read(filename_queue) 
#定义了多个解码器,每个解码器跟一个reader相连,这里reader设置为2 
example_list = [tf.decode_csv(value, record_defaults=[[1], [1]]) for _ in range(2)]
#使用tf.train.batch_join(),可以使用多个reader,并行读取数据,每个Reader使用一个线程
example, label = tf.train.batch_join(example_list, batch_size=3)  

with tf.Session() as sess:  
    coord = tf.train.Coordinator()  
    threads = tf.train.start_queue_runners(coord=coord)  
    for i in range(10):  
        example_batch, label_batch = sess.run([example, label])  
        print(example_batch, label_batch)
    coord.request_stop()  
    coord.join(threads) 

tf.train.batch与tf.train.shuffle_batch函数是单个Reader读取,可以多线程(即batch_size>1)。tf.train.batch_join与tf.train.shuffle_batch_join 可以设置多Reader读取,每个Reader使用一个线程。至于两种方法的效率,单Reader时,2个线程就达到了速度的极限。多Reader时,2个Reader就达到了极限。所以并不是线程越多越快,太多的线程会使效率下降。

tf.decode_csv()中的 record_defaults = [[1], [1]]:record_defaults是解析的模板,每行有几列单元就有几个[1];整型数值解析标准是[1],浮点型是[1.0],字符型是['null']。

从图像文件读取数据 

首要目标是获得图像名列表。

可以把图像文件路径存到xlsx或txt文件中,一行一个样本,再用python方法读取文件名列表。

可以用tf.gfile直接获取图像文件夹内所有文件名。(下方代码是这种)

# -*- coding:utf-8 -*-

import tensorflow as tf
import os.path

filenames = tf.gfile.ListDirectory('image_dir')
filenames = [os.path.join('image_dir', f) for f in filenames]  #文件完整路径

filename_queue = tf.train.string_input_producer(filenames, shuffle=True)  
#定义阅读器
reader = tf.WholeFileReader()  
key, value = reader.read(filename_queue) 
#定义解码器
image = tf.image.decode_jpeg(value, channels=3)
image = tf.reshape(image, [image_size, image_size, 3])
image_batch = tf.train_batch(image, batch_size=3) 

with tf.Session() as sess:
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)
    for i in range(10):
        example_batch = sess.run(image_batch)  
        print(example_batch.shape)
    coord.request_stop()  
    coord.join(threads) 

从TFRecords文件读取数据

这种方式先要将你的数据转换为tensorflow标准格式TFRecords文件,它实际上是一种二进制文件,虽然不好理解,但能更好的利用内存,更容易与TF网络架构匹配。

TFRecords文件包含了 tf.train.Example 协议内存块(protocol buffer)(协议内存块包含了字段 Features)。你可以写一段代码获取你的数据, 将数据填入到Example协议内存块(protocol buffer),将协议内存块序列化为一个字符串,并且通过 tf.python_io.TFRecordWriter 类写入到TFRecords文件。

从TFRecords文件中读取数据, 可以使用 tf.TFRecordReader 的 tf.parse_single_example 解析器。这个操作可以将Example协议内存块(protocol buffer)解析为张量。

import os
import tensorflow as tf 
from PIL import Image

#classes是根据数据类型自定义的列表
#比如我把所有图像分类存放在class_0、class_1、class_2文件夹里
#那么classes=['class_0','class_1','class_2']
writer = tf.python_io.TFRecordWriter("train.tfrecords")
for index, name in enumerate(classes):
    class_path = dataset_dir + name + "\"
    #举个例子则class_path可能是“E:\image_dataset\class_0\”
    for img_name in os.listdir(class_path):
        img_path = class_path + img_name
        img = Image.open(img_path)
        img = img.resize((224, 224))
        img_raw = img.tobytes()
        example = tf.train.Example(features=tf.train.Features(feature={"label":tf.train.Feature(int64_list=tf.train.Int64List(value=[index])), 'img_raw':tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))}))
        writer.write(example.SerializeToString()) #序列化为字符串
writer.close()

于是数据相关的信息都被存到了一个文件中,包括example和label。

生成了TFRecords文件后,再使用队列读取数据,代码如下:

#生成文件名队列
filename = "train.tfrecords"
filename_queue = tf.train.string_input_producer([filename])
#定义阅读器
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)   
#返回文件名和文件
features = tf.parse_single_example(serialized_example, features={'label': tf.FixedLenFeature([], tf.int64), 'img_raw' : tf.FixedLenFeature([], tf.string)})
image = tf.decode_raw(features['img_raw'], tf.uint8)
image = tf.reshape(image, [224, 224, 3])
image = tf.cast(image, tf.float32)*(1./255)
label = tf.cast(features['label'], tf.int32)

images, labels = tf.train.shuffle_batch([image, label], batch_size=30, capacity=2000, min_after_dequeue=1000)

init = tf.initialize_all_variables()
with tf.Session() as sess:
    sess.run(init)
    threads = tf.train.start_queue_runners(sess=sess)
    for i in range(10):
        image_batch, label_batch= sess.run([images, labels])
        print(image_batch.shape, label_batch)

* 确实比前两种方法麻烦,但是既然是官方标准格式,它总有自己的好处。

* 因为TF的graph能够记住状态(state),就是说TFRecordReader能够记住tfrecord的位置,这样才能不断返回下一个文件。因此在使用之前,必须初始化整个graph,tf.initialize_all_variables()的作用就是初始化。

* sess.run()时队列才执行,TFRecordReader会不断弹出队里中文件名,直到队列为空。




  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值