Take the anime dataset as an example:
import multiprocessing

import tensorflow as tf


def batch_dataset(dataset,
                  batch_size,
                  drop_remainder=True,
                  n_prefetch_batch=1,
                  filter_fn=None,
                  map_fn=None,
                  n_map_threads=None,
                  filter_after_map=False,
                  shuffle=True,
                  shuffle_buffer_size=None,
                  repeat=None):
    # set defaults
    if n_map_threads is None:
        n_map_threads = multiprocessing.cpu_count()
    if shuffle and shuffle_buffer_size is None:
        shuffle_buffer_size = max(batch_size * 128, 2048)  # buffer 128 batches, but at least 2048 samples

    # [*] it is more efficient to `shuffle` before `map`/`filter`, because `map`/`filter` can be costly
    if shuffle:
        dataset = dataset.shuffle(shuffle_buffer_size)

    if not filter_after_map:
        if filter_fn:
            dataset = dataset.filter(filter_fn)
        if map_fn:
            dataset = dataset.map(map_fn, num_parallel_calls=n_map_threads)
    else:  # [*] this is slower
        if map_fn:
            dataset = dataset.map(map_fn, num_parallel_calls=n_map_threads)
        if filter_fn:
            dataset = dataset.filter(filter_fn)

    dataset = dataset.batch(batch_size, drop_remainder=drop_remainder)
    dataset = dataset.repeat(repeat).prefetch(n_prefetch_batch)

    return dataset
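To make the control flow above concrete, here is a quick sanity check with a toy in-memory dataset (my own sketch, not from the original course; the lambda and the sizes are arbitrary). Note that `repeat=1` is passed explicitly, since the default `repeat=None` repeats the dataset indefinitely:

toy = tf.data.Dataset.range(10)
batched = batch_dataset(toy,
                        batch_size=4,
                        map_fn=lambda x: x * x,  # applied after shuffle, per the ordering above
                        repeat=1)
for batch in batched:
    print(batch.numpy())  # two batches of 4; the remaining 2 samples are dropped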
def memory_data_batch_dataset(memory_data,
                              batch_size,
                              drop_remainder=True,
                              n_prefetch_batch=1,
                              filter_fn=None,
                              map_fn=None,
                              n_map_threads=None,
                              filter_after_map=False,
                              shuffle=True,
                              shuffle_buffer_size=None,
                              repeat=None):
    """Batch dataset of in-memory data.

    Parameters
    ----------
    memory_data : nested structure of tensors/ndarrays/lists

    """
    dataset = tf.data.Dataset.from_tensor_slices(memory_data)  # convert the in-memory data (here, image paths) into a tensor dataset
    dataset = batch_dataset(dataset,
                            batch_size,
                            drop_remainder=drop_remainder,
                            n_prefetch_batch=n_prefetch_batch,
                            filter_fn=filter_fn,
                            map_fn=map_fn,
                            n_map_threads=n_map_threads,
                            filter_after_map=filter_after_map,
                            shuffle=shuffle,
                            shuffle_buffer_size=shuffle_buffer_size,
                            repeat=repeat)
    return dataset
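For instance, feeding a hypothetical pair of NumPy arrays through this function yields (x, y) batches; this is a sketch under the assumption that features and labels are aligned along the first axis:

import numpy as np

x = np.random.rand(100, 8).astype(np.float32)  # hypothetical features
y = np.random.randint(0, 2, size=100)          # hypothetical labels
ds = memory_data_batch_dataset((x, y), batch_size=32)
for xb, yb in ds.take(1):  # take(1) because the default repeat=None loops forever
    print(xb.shape, yb.shape)  # (32, 8) (32,)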
def disk_image_batch_dataset(img_paths,
                             batch_size,
                             labels=None,
                             drop_remainder=True,
                             n_prefetch_batch=1,
                             filter_fn=None,
                             map_fn=None,
                             n_map_threads=None,
                             filter_after_map=False,
                             shuffle=True,
                             shuffle_buffer_size=None,
                             repeat=None):
    """Batch dataset of on-disk images (PNG and JPEG).

    Parameters
    ----------
    img_paths : 1d-tensor/ndarray/list of str
    labels : nested structure of tensors/ndarrays/lists

    """
    if labels is None:  # at this point no image data has been read into memory yet; we only have paths
        memory_data = img_paths
    else:
        memory_data = (img_paths, labels)

    def parse_fn(path, *label):  # read the image data into memory
        img = tf.io.read_file(path)
        img = tf.image.decode_jpeg(img, channels=3)  # fix channels to 3
        # For medical images in DICOM format, use tfio.image.decode_dicom_image()
        # from the tensorflow_io package (import tensorflow_io as tfio), e.g.:
        #   img = tf.io.read_file('xx.dcm')  # read the raw DICOM bytes into memory
        #   img = tfio.image.decode_dicom_image(img)
        return (img,) + label

    if map_fn:  # fuse `map_fn` and `parse_fn`
        def map_fn_(*args):
            return map_fn(*parse_fn(*args))
    else:
        map_fn_ = parse_fn

    dataset = memory_data_batch_dataset(memory_data,
                                        batch_size,
                                        drop_remainder=drop_remainder,
                                        n_prefetch_batch=n_prefetch_batch,
                                        filter_fn=filter_fn,
                                        map_fn=map_fn_,
                                        n_map_threads=n_map_threads,
                                        filter_after_map=filter_after_map,
                                        shuffle=shuffle,
                                        shuffle_buffer_size=shuffle_buffer_size,
                                        repeat=repeat)

    return dataset
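A possible call looks like this (the glob pattern is an assumption). Because the user's map_fn is fused with parse_fn, it receives the decoded image rather than the raw path; the resize is needed so that variable-size JPEGs can be batched:

import glob

paths = glob.glob('data/anime/*.jpg')  # hypothetical image folder
ds = disk_image_batch_dataset(paths,
                              batch_size=16,
                              map_fn=lambda img: tf.image.resize(img, [96, 96]))
imgs = next(iter(ds))
print(imgs.shape)  # (16, 96, 96, 3)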
# Main entry point for loading a custom dataset into TensorFlow. The
# `drop_remainder` argument controls whether to drop the leftover samples when
# the dataset size is not divisible by `batch_size`.
def make_anime_dataset(img_paths, batch_size, resize=64, drop_remainder=True, shuffle=True, repeat=1):
    # @tf.function
    def _map_fn(img):  # normalize the image data
        img = tf.image.resize(img, [resize, resize])
        # img = tf.image.random_crop(img, [resize, resize, 3])
        # img = tf.image.random_flip_left_right(img)
        # img = tf.image.random_flip_up_down(img)
        img = tf.clip_by_value(img, 0, 255)
        img = img / 127.5 - 1  # scale to [-1, 1]
        return img

    dataset = disk_image_batch_dataset(img_paths,
                                       batch_size,
                                       drop_remainder=drop_remainder,
                                       map_fn=_map_fn,
                                       shuffle=shuffle,
                                       repeat=repeat)
    img_shape = (resize, resize, 3)
    len_dataset = len(img_paths) // batch_size

    return dataset, img_shape, len_dataset
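Putting it together, a hypothetical driver might look like the following; per `_map_fn` above, pixel values end up in [-1, 1], which matches a generator with a tanh output:

import glob

img_paths = glob.glob('faces/*.jpg')  # path is an assumption
dataset, img_shape, len_dataset = make_anime_dataset(img_paths, batch_size=64)
print(img_shape, len_dataset)  # (64, 64, 3), len(img_paths) // 64

for batch in dataset:  # repeat=1, so this iterates one epoch of len_dataset batches
    print(batch.shape, batch.numpy().min(), batch.numpy().max())
    break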
'''
My understanding of the code:

The image paths are converted into tensors; the function passed as the first
argument to `map` reads the images into memory and normalizes them, and this
function is invoked through a callback mechanism. The batch size and
`drop_remainder` are both handled by the `dataset.batch` API.

This is only a rough understanding and I am not sure whether it is correct;
corrections from more experienced readers are very welcome.
'''
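The "callback mechanism" mentioned above is essentially function composition: `tf.data` calls the fused `map_fn_` once per element, which in turn calls `parse_fn` and then the user's `map_fn`. A stripped-down illustration with hypothetical stand-in functions:

def parse_fn(path):
    return ('decoded:' + path,)   # stands in for read_file + decode_jpeg

def user_map_fn(img):
    return 'normalized:' + img    # stands in for resize + scaling

def map_fn_(*args):               # the fused function handed to dataset.map
    return user_map_fn(*parse_fn(*args))

print(map_fn_('img_001.jpg'))     # normalized:decoded:img_001.jpg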
This code comes from GAN实战-3 of 龙良曲's course《深度学习与TensorFlow入门实战》. Bilibili has since taken the video down, so there is no link to cite as a repost and I have to shamelessly mark this as original; the video can now only be found on some cloud drive.