使用tf.data加载数据
动机
此篇教程指于如何使用tf.data加载数据,本文所用图片分布在各个目录之中,每一个目录拥有一类图片。
目标
- 了解数据集所需要的信息
- 建立tf.data.Dataset对象
图像数据集下载后需要了解的信息:
- 图像的数量、格式和路径,标签数量、格式和标签路径
- 图像的尺寸以及通道数
- 图像使用图像库随机显示部分图像
建立tf.data.Dataset
最简单的建立tf.data.Dataset的方式是使用from_tensor_slices方法。例如,创建图像数据路径Dataset实例。
创建Dataset实例后,使用map函数依据图像路径进而对图像进行处理。处理包括加载图像、resize,修改格式和压缩等。
假设all_image_paths是一个包含所有图像路径的列表。代码链接https://www.tensorflow.org/tutorials/load_data/images
1.使用gfile读图片,decode输出是Tensor
import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np
print(tf.__version__)
image_raw = tf.gfile.FastGFile('test/a.jpg','rb').read() #bytes
img = tf.image.decode_jpeg(image_raw) #Tensor
2.使用WholeFileReader输入queue,decode输出是Tensor,eval后是ndarray(被tf.data替代)
3.使用read_file,decode输出是Tensor,eval后是ndarray
import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np
print(tf.__version__)
image_value = tf.read_file('test/a.jpg')
img = tf.image.decode_jpeg(image_value, channels=3)