1. Keras/Tensorflow 2.0 自定义数据集 Dataset

在学习Tensorflow的过程中,发现大多数教程都是基于现有的数据集进行训练、优化。

例如:MNIST识别教程,一个

(x_train, y_train), (x_test, y_test) = mnist.load_data()

即可获得训练、测试数据集。

而在解决实际问题时,我们经常面对的是采集到的原始图片信息,这些图片保存在硬盘当中,当模型搭建好以后开始把数据从硬盘加载到内存,然后计算。然而加载数是需要时间的,如果图片数据比较大,那么无疑浪费了很多数据读取的时间。

我们期望的是:将这些图片信息制作成带标签的数据集,并能方便的shuffle、batch,快速、高效的提供给模型进行训练。

本文以一个花卉识别的例子来展示如何利用Tensorflow的pipeline和缓存技术来方便、快捷的实现一个自定义数据集。

1. 获取图片数据:

from __future__ import absolute_import, division, print_function, unicode_literals

import os
import time
import tensorflow as tf
import pathlib
import random
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

data_root_orig = tf.keras.utils.get_file(origin='https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
                                         fname='flower_photos', untar=True)

执行完毕后会在~/.keras/datasets/下保存包含5种花卉图片的文件夹:

2. 查看图片:

快速浏览几张图片,以知道你在处理什么:

data_root = pathlib.Path(data_root_orig)
all_image_paths = list(data_root.glob('*/*'))
all_image_paths = [str(path) for path in all_image_paths]
random.shuffle(all_image_paths)

image_count = len(all_image_paths)
print('Image count: ', image_count)

plt.figure('image show')
for n in range(3):
	image_path = random.choice(all_image_paths)
	label = image_path.split('/')[-2]
	image = Image.open(image_path)
	print(image.size)

	plt.subplot(1, 3, n+1)
	plt.title(label)
	plt.imshow(image)
plt.show()

3. 确定每张图片的标签:

label_names = sorted(item.name for item in data_root.glob('*/') if item.is_dir())
print('label_names: ', label_names) # ['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']

label_to_index = dict((name, index) for index, name in enumerate(label_names))
print('label_to_index: ', label_to_index) # {'sunflowers': 3, 'daisy': 0, 'roses': 2, 'tulips': 4, 'dandelion': 1}

all_image_labels = [label_to_index[pathlib.Path(path).parent.name] for path in all_image_paths]
print("First 10 labels indices: ", all_image_labels[:10]) # [2, 2, 2, 2, 3, 4, 1, 1, 3, 2]

3. 读取和格式化图片:

主要工作是通过tf.io.read_file将图片路径名转化为图片张量,并将每个像素值转换为[0 - 1]的范围(方便训练)。

def preprocess_image(img_raw):
	img_tensor = tf.image.decode_jpeg(contents=img_raw, channels=3) # can be used for plt.imshow(img_tensor)
	img_final = tf.image.resize(images=img_tensor, size=[192, 192])
	img_final /= 255.0 # normalize to [0,1] range
	return img_final

def load_and_preprocess_image(path):
	img_raw = tf.io.read_file(path) # can't be used for plt.imshow(img_raw)
	return preprocess_image(img_raw)

def load_and_preprocess_from_path_label(path, label):
	return load_and_preprocess_image(path), label

def load_and_preprocess_image(path):
	img_raw = tf.io.read_file(path)

	return preprocess_image(img_raw)

4. 构建Dataset:

ds = tf.data.Dataset.from_tensor_slices((all_image_paths, all_image_labels))
image_label_ds = ds.map(load_and_preprocess_from_path_label)

all_image_paths和all_image_labels这两个list中,每张图片和其标签是一一对应的,因此可以打包为一个(图片 - 标签)组。

tf.data.Dataset.from_tensor_slices返回的ds具有很多实用的方法用来操作数据集,例如:shuffle、batch、repeat等,方便后来加载进模型进行训练。

5. 加载图片数据集:

为了高效的从硬盘加载进内存,我们采用了Tensorflow的缓存技术,并且在图片数据远大于内存RAM大小时,仍然可以获得较高的性能。

BATCH_SIZE = 32

ds = image_label_ds.cache(filename='./cache.tf-data')
ds = ds.shuffle(buffer_size=image_count, reshuffle_each_iteration=True).repeat()
ds = ds.batch(BATCH_SIZE).prefetch(buffer_size=1)

至此,Dataset数据集构建完毕,可以用来高效的训练模型。

下一篇将以一个迁移学习的例子展示如何利用Dataset来训练模型。

  • 2
    点赞
  • 28
    收藏
    觉得还不错? 一键收藏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值