引言
TensorFlow1.x与TensorFlow2.x在数据读取方面存在如下异同点。
相同点:
- 深度学习有大量图片处理操作,图像处理相关API并没有太大的变化。诸如 tf.image.decode_image()、tf.image.resize_image() 等图像处理相关的 API 在TF1和TF2中都可以使用,因为它们属于 TensorFlow 的核心图像处理函数,不依赖于数据集对象。
不同点:
- 数据集对象:TF2 引入了 tf.data.Dataset 这个类来构建数据流水线进行数据处理,而在 TF1 中并不存在这个类。TF1 主要使用 tf.data.QueueRunner,或者 tf.placeholder 和 feed_dict 来传递数据。
- Eager Execution:TF2 默认启用了 Eager Execution 模式,允许立即执行操作并获得结果,而 TF1 则默认使用静态图模式,需要先构建计算图然后执行。
本文以读取图片文件案例展示tf1和tf2在数据读取方面的不同。
TF1版本案例
import tensorflow as tf
import os
def picread(file_list):
# 1、构造文件名队列
# 返回文件名队列
file_queue = tf.train.string_input_producer(file_list)
print("file_queue:\n", file_queue)
# 2、构造一个图片读取器,去队列中读取样本
# 返回reader实例,调用read方法读取内容,key, value
reader = tf.WholeFileReader()
print("reader:\n", reader)
key, value = reader.read(file_queue)
print("key:\n", key)
print("value:\n", value)
# 3、对样本内容进行解码
image = tf.image.decode_jpeg(value)
print("image:\n", image)
# 处理图片的大小,形状
image_resize = tf.image.resize_images(image, [200, 200])
print("image_resize:\n", image_resize)
# 设置固定形状,这里可以使用静态形状API去修改
image_resize.set_shape([200, 200, 3])
# 4、批处理图片数据
# 每个样本的形状必须全部定义
image_batch = tf.train.batch([image_resize], batch_size=100, num_threads=1, capacity=100)
print("image_batch:\n", image_batch)
return image_batch
if __name__ == "__main__":
# 生成路径/文件名的列表
filename = os.listdir("./dog")
file_list = [os.path.join("./dog/", file) for file in filename]
image_batch = picread(file_list)
# 开启会话打印内容
with tf.Session() as sess:
# 创建线程协调器
coord = tf.train.Coordinator()
# 开启子线程去读取数据
# 返回子线程实例
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
# 获取样本数据去训练
print(sess.run(image_batch))
# 关闭子线程,回收
coord.request_stop()
coord.join(threads)
TF2版本案例
import tensorflow as tf
import os
def picread(file_list):
# 1、构造文件名队列
# 返回文件名队列
file_queue = tf.data.Dataset.from_tensor_slices(file_list)
print("file_queue:\n", file_queue)
# 2、构造一个图片读取器,去队列中读取样本
# 返回reader实例,调用read方法读取内容,key, value
def read_file(filename):
image = tf.io.read_file(filename)
return filename, image
dataset = file_queue.map(read_file)
print("dataset:\n", dataset)
# 3、对样本内容进行解码
def decode_image(filename, image):
image = tf.image.decode_jpeg(image, channels=3)
return filename, image
dataset = dataset.map(decode_image)
print("decoded dataset:\n", dataset)
# 处理图片的大小,形状
def resize_image(filename, image):
image = tf.image.resize(image, [200, 200])
image.set_shape([200, 200, 3])
return filename, image
dataset = dataset.map(resize_image)
print("resized dataset:\n", dataset)
# 4、批处理图片数据
# 每个样本的形状必须全部定义
dataset = dataset.batch(100) # 目标目录中有100张图片
print("batched dataset:\n", dataset)
return dataset
if __name__ == "__main__":
# 生成路径/文件名的列表
filenames = os.listdir("./dog")
file_list = [os.path.join("./dog/", file) for file in filenames]
image_dataset = picread(file_list)
# 获取样本数据去训练
for filenames, images in image_dataset:
print(filenames, images)
......
注意,这里面使用了 map 方法进行隐式参数传递,当在 map 方法中调用 read_file、decode_image 和 resize_image 这些函数时,Dataset 会自动将每个元素作为参数传递给这些函数。