Tensorflow2.x读取自定义数据集(图像分类)

最近在积攒粉丝500,大家帮帮忙,动动小手指关注、点赞、收藏…🙏🙏🙏🙏🙏🙏


一、说明

在深度学习图像分类模型设计的数据集合中,通常包含训练集train set、验证集validation set、测试集test set;

在这里插入图片描述

  • 训练集train: 用于模型学习训练;
  • 验证集valid:用于训练过程中模型评估、调整超参数、监控模型是否发生过拟合等。
  • 测试集test:用于最终评估模型泛化能力。

二、数据集的目录结构

例如猫狗图像分类数据集的目录结构如下:
在这里插入图片描述

train、valid、test文件夹下包含一个个子文件夹,每个子文件夹是一个类别;如train目录下有猫cat类别、dog类别;

每个类别文件夹下包含了对应类别的图像;例如猫cat类别的图像:
在这里插入图片描述

三、读取代码

通常分单独的集合读取。函数代码如下:

3.1 获取集合下的所有图片路径

def get_all_image_paths(image_dir):
    '''
    获取所有图片路径,例如 ['mycatdog2/train/cat/cat_1.jpg', ...]
    image_dir: train/valid/test目录;如:mycatdog2/train
    '''
    data_path = pathlib.Path(image_dir)
    paths = list(data_path.glob('*/*'))     # 图片全路径
    paths = [str(p) for p in paths]
    return paths

3.2 获取类别名称及其数字标签

def get_label_and_index(image_dir):
    '''获取类别名称及其数字标签,例如
        ['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']
        {'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflowers': 3, 'tulips': 4}
    '''
    data_path = pathlib.Path(image_dir)
    label_names = sorted(item.name for item in data_path.glob('*/') if item.is_dir())
    label_index = dict((name,index) for index,name in enumerate(label_names))
    return label_names, label_index

3.3 Tensorflow2.x的接口读取数据集

调用以上两个函数和tensorflow2.0接口分批次读取数据集。

def process_image(fpath, label):
    """ 图片预处理 """
    image = tf.io.read_file(fpath)                  # 读取图像
    image = tf.image.decode_jpeg(image,channels=3)  # jpg图像解码
    image = tf.image.resize(image, [112, 112])      # 原始图片大重设为(x, x), AlexNet的输入是224X224
    label = tf.one_hot(label, depth=2)              # 标签转成onehot格式,这里实验是标签2个类别数据
    return image, label

def get_dataset(image_dir, is_shuffle=False, batch_size=64):
    # 获取所有图片路径
    image_paths = get_all_image_paths(image_dir)
    _, label_index = get_label_and_index(image_dir)
    # 每个图片路径名->数字标签
    image_labels = [label_index[pathlib.Path(path).parent.name] for path in image_paths]
    # tensorflow接口创建数据集读取
    ds = tf.data.Dataset.from_tensor_slices((image_paths, image_labels))
    # 回调数据处理
    ds = ds.map(process_image)
    # 洗牌
    if is_shuffle:
        ds = ds.shuffle(buffer_size=len(image_paths))
    # 分批次
    ds = ds.batch(batch_size)
    return ds

3.4 测试

测试一下读取一个batch的数据:

if __name__ == "__main__":
    ds = get_dataset("F:\dataset\mycatdog2\\train")
    for x, y in ds:
        print("x:", x.shape)
        print("y:", y.shape)
        break

如下:
在这里插入图片描述

  • 4
    点赞
  • 23
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 4
    评论
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

AI学长

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值