采用猴子数据集:一共10个类
代码如下
import tensorflow as tf
import pathlib
import random
# Root directory of the training split.
path = './monkey/training/training'
data_path = pathlib.Path(path)

# Collect every entry that sits exactly two levels below the root
# (root/<sub-directory>/<entry>) and keep each one as a plain string path.
all_images_path = list(map(str, data_path.glob('*/*')))
random.shuffle(all_images_path)  # shuffle the paths in place

print(len(all_images_path))
print(all_images_path[:5])  # preview the first five paths
# Build the labels.
# Class names are the names of the sub-directories directly under the data
# root, sorted so that indices follow the sorted order of the names.
label_names = sorted(
    entry.name for entry in data_path.glob('*/') if entry.is_dir()
)
print(label_names)  # show the class names; next, map each name to an integer label

label_to_index = {class_name: idx for idx, class_name in enumerate(label_names)}

# The label of an image is the index of the directory that contains it.
all_image_labels = [
    label_to_index[pathlib.Path(image_path).parent.name]
    for image_path in all_images_path
]

# Spot-check the first five (path, label) pairs.
for image, label in zip(all_images_path[:5], all_image_labels[:5]):
    print(image, '-----', label)
# Create the Dataset by slicing the path list and the label list together,
# so each element pairs one path with its label.
ds = tf.data.Dataset.from_tensor_slices(
    tensors=(all_images_path, all_image_labels),
)
# Image-loading function applied to each (path, label) element.
def load_and_preprocess_from_path_label(my_path, my_label):
    """Load the image stored at ``my_path`` and return it with its label.

    The file is read, decoded as a JPEG with ``channels=1``, resized to
    150 x 150 and divided by 255.0 (presumably to scale pixel values into
    the [0, 1] range).

    Args:
        my_path: Path of the image file to read.
        my_label: Label associated with the image; returned unchanged.

    Returns:
        A ``(image, label)`` tuple holding the preprocessed image tensor
        and ``my_label``.
    """
    raw_bytes = tf.io.read_file(my_path)
    decoded = tf.image.decode_jpeg(raw_bytes, channels=1)
    resized = tf.image.resize(decoded, [150, 150])
    return resized / 255.0, my_label
# Apply the loading function to every (path, label) element of the dataset.
image_label_ds = ds.map(map_func=load_and_preprocess_from_path_label)
print(image_label_ds)
参考:https://www.cnblogs.com/chenhuabin/p/11863889.html