原代码来自tensorflow
Tensorflow学习笔记 (用 tf.data 加载图片)
本教程提供一个如何使用 tf.data 加载图片的简单例子。
导入模块,配置
import tensorflow as tf
tf.data用于数据集的构建与预处理
AUTOTUNE = tf.data.experimental.AUTOTUNE
下载并检查数据集
从origin网址中下载文件,命名为’flower_photos’,untar=True表示对文件进行解压。
通过pathlib.Path(data_root_orig)获得文件的路径(虽然data_root_orig 也表示下载的文件路径,但pathlib.Path可以支持不同的操作系统)
import pathlib
data_root_orig = tf.keras.utils.get_file(origin='https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz', fname='flower_photos', untar=True)
data_root = pathlib.Path(data_root_orig)
print(data_root)
查看data_root路径下的文件
data_root路径下有5个子文件夹和1txt文件。
for item in data_root.iterdir():
print(item)
data_root.路径下有6个子文件,data_root.glob(’/’)表示读取子文件夹中的所有图片。
import random
all_image_paths = list(data_root.glob('*/*'))
all_image_paths
列表中的部分内容如下所示
[WindowsPath(‘C:/Users/HUAWEI/.keras/datasets/flower_photos/daisy/100080576_f52e8ee070_n.jpg’),
WindowsPath(‘C:/Users/HUAWEI/.keras/datasets/flower_photos/daisy/10140303196_b88d3d6cec.jpg’),
去掉WindowsPath,只留下图片的路径
all_image_paths = [str(path) for path in all_image_paths]
all_image_paths
打乱图片路径顺序
查看一共有多少图片
random.shuffle(all_image_paths)
image_count = len(all_image_paths)
image_count
检查图片
打开data_root路径下的"LICENSE.txt"文件,编码为’utf-8’,读取文件第4行以后的内容。
将列表中的每一项以’ CC-BY’作为分隔符分开。
import os
attributions = (data_root /"LICENSE.txt").open(encoding='utf-8').readlines()[4:]
attributions = [line.split(' CC-BY') for line in attributions]
将attributions变成字典
attributions = dict(attributions