从图片文件中生成一个tf.data.Dataset的类:自动划分训练集和验证集。
源码:tensorflow2.11版本:
def image_dataset_from_directory(
directory,
labels="inferred",
label_mode="int",
class_names=None,
color_mode="rgb",
batch_size=32,
image_size=(256, 256),
shuffle=True,
seed=None,
validation_split=None,
subset=None,
interpolation="bilinear",
follow_links=False,
crop_to_aspect_ratio=False,
**kwargs,
):
参数:
- directory: 数据所在目录。如果标签是“inferred”(默认),则它应该包含子目录,每个目录包含一个类的图像。否则,将忽略目录结构。
- labels: “inferred”(标签从目录结构生成),或者是整数标签的列表/元组,其大小与目录中找到的图像文件的数量相同。标签应根据图像文件路径的字母顺序排序(通过Python中的os.walk(directory)获得)。
- label_mode:
- 'int':表示标签被编码成整数(例如:
sparse_categorical_crossentropy
loss),y_true被编码成一个整数值。 - ‘categorical’指标签被编码为分类向量(例如:
categorical_crossentropy
loss),y_true是一个多分类的结果。 - ‘binary’意味着标签(只能有2个)被编码为值为0或1的float32标量(例如:binary_crossentropy)。
- None(无标签)
- class_names: 仅当“labels”为“inferred”时有效。这是类名称的明确列表(必须与子目录的名称匹配)。用于控制类的顺序(否则使用字母数字顺序)。
- color_mode: "grayscale"、"rgb"、"rgba"之一。默认值:"rgb"。图像将被转换为1、3或者4通道。
- batch_size: 数据批次的大小。默认值:32
- image_size: 从磁盘读取数据后将其重新调整大小。默认:(256,256)。由于管道处理的图像批次必须具有相同的大小,因此该参数必须提供。
- shuffle: 是否打乱数据。默认值:True。如果设置为False,则按字母数字顺序对数据进行排序。
- seed: 用于shuffle和转换的可选随机种子。
- validation_split: 0和1之间的可选浮点数,可保留一部分数据用于验证。
- subset: "training"或"validation"之一。仅在设置validation_split时使用。
- interpolation: 字符串,当调整图像大小时使用的插值方法。默认为:
bilinear。支持bilinear
,nearest
,bicubic
,area
,lanczos3
,lanczos5
,gaussian
,mitchellcubic
.。 - follow_links: 是否访问符号链接指向的子目录。默认:False。
生成器的Return返回值:
一个tf.data.Dataset对象。
- 如果label_mode为None,它将生成float32张量,其shape为(batch_size, image_size[0], image_size(1), num_channels),并对图像进行编码(有关num_channels的规则,参见下文)。
- 否则,将生成一个元组(images, labels),其中图像的shape为(batch_size, image_size[0], image_size(1), num_channels),并且labels遵循下面描述的格式。
关于labels格式规则:
- 如果label_mode 是 int, labels是形状为(batch_size, )的int32张量
- 如果label_mode 是 binary, labels是形状为(batch_size, 1)的1和0的float32张量。
- 如果label_mode 是 categorial, labels是形状为(batch_size, num_classes)的float32张量,表示类索引的one-hot编码。(多分类)
自定义图片数据集生成器:
原始数据集图片文件夹的结构:data文件夹中有多少个子文件夹,就表明有多少个类。
文件命名建议使用英文。
def data_generate(data_dir,batch_size,img_height,img_width):
train_dataset=tf.keras.preprocessing.image_dataset_from_directory(
data_dir,
label_mode='categorical',#生成多分类;标签被编码成分类向量; binary ;
validation_split=0.2,
batch_size=batch_size,
shuffle=True,
seed=120,
subset='training',
image_size=(img_height,img_width),
# color_mode='rgb' #默认图片的是rgb
)
val_dataset=tf.keras.preprocessing.image_dataset_from_directory(
data_dir,
label_mode='categorical',
validation_split=0.2,
batch_size=batch_size,
seed=120,
shuffle=True,
subset='validation',
image_size=(img_height,img_width)
)
class_names=train_dataset.class_names #对应数据集的labels 值
return train_dataset,val_dataset,class_names
调用时:只传入最外层的文件夹:data即可
#生成数据集:
train_dataset,val_dataset,class_names=data_generate('../data',10,224,224)
生成train_dataset是一个元组:(images,labels):
labels:
['lavender', 'peach_blossom', 'roses', 'sunflowers', 'tulip']