最近在积攒粉丝500,大家帮帮忙,动动小手指关注、点赞、收藏…🙏🙏🙏🙏🙏🙏
一、说明
在深度学习图像分类模型设计的数据集合中,通常包含训练集train set、验证集validation set、测试集test set;
- 训练集train: 用于模型学习训练;
- 验证集valid:用于训练过程中模型评估、调整超参数、监控模型是否发生过拟合等。
- 测试集test:用于最终评估模型泛化能力。
二、数据集的目录结构
例如猫狗图像分类数据集的目录结构如下:
train、valid、test文件夹下包含一个个子文件夹,每个子文件夹是一个类别;如train目录下有猫cat类别、dog类别;
每个类别文件夹下包含了对应类别的图像;例如猫cat类别的图像:
三、读取代码
通常分单独的集合读取。函数代码如下:
3.1 获取集合下的所有图片路径
def get_all_image_paths(image_dir):
'''
获取所有图片路径,例如 ['mycatdog2/train/cat/cat_1.jpg', ...]
image_dir: train/valid/test目录;如:mycatdog2/train
'''
data_path = pathlib.Path(image_dir)
paths = list(data_path.glob('*/*')) # 图片全路径
paths = [str(p) for p in paths]
return paths
3.2 获取类别名称及其数字标签
def get_label_and_index(image_dir):
'''获取类别名称及其数字标签,例如
['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']
{'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflowers': 3, 'tulips': 4}
'''
data_path = pathlib.Path(image_dir)
label_names = sorted(item.name for item in data_path.glob('*/') if item.is_dir())
label_index = dict((name,index) for index,name in enumerate(label_names))
return label_names, label_index
3.3 Tensorflow2.x的接口读取数据集
调用以上两个函数和tensorflow2.0接口分批次读取数据集。
def process_image(fpath, label):
""" 图片预处理 """
image = tf.io.read_file(fpath) # 读取图像
image = tf.image.decode_jpeg(image,channels=3) # jpg图像解码
image = tf.image.resize(image, [112, 112]) # 原始图片大重设为(x, x), AlexNet的输入是224X224
label = tf.one_hot(label, depth=2) # 标签转成onehot格式,这里实验是标签2个类别数据
return image, label
def get_dataset(image_dir, is_shuffle=False, batch_size=64):
# 获取所有图片路径
image_paths = get_all_image_paths(image_dir)
_, label_index = get_label_and_index(image_dir)
# 每个图片路径名->数字标签
image_labels = [label_index[pathlib.Path(path).parent.name] for path in image_paths]
# tensorflow接口创建数据集读取
ds = tf.data.Dataset.from_tensor_slices((image_paths, image_labels))
# 回调数据处理
ds = ds.map(process_image)
# 洗牌
if is_shuffle:
ds = ds.shuffle(buffer_size=len(image_paths))
# 分批次
ds = ds.batch(batch_size)
return ds
3.4 测试
测试一下读取一个batch的数据:
if __name__ == "__main__":
ds = get_dataset("F:\dataset\mycatdog2\\train")
for x, y in ds:
print("x:", x.shape)
print("y:", y.shape)
break
如下: