Normally, when you do data augmentation inside a tf.data.Dataset pipeline, you can only use the ops that TensorFlow itself provides, which is quite limiting. For instance, suppose you want to blur images at random with OpenCV (a random Gaussian blur plus a random motion blur): core tf.image has no op for either, so a TF-only map is stuck with whatever tf.image offers, as in the sketch below.
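A minimal sketch of that TF-only style, for contrast (the file paths here are hypothetical, and only tf.io / tf.image ops are used):

import tensorflow as tf

def tf_only_augment(filename):
    # Every step must be a TensorFlow op; there is no way to call OpenCV here.
    image = tf.io.read_file(filename)
    image = tf.image.decode_jpeg(image, channels=3)
    image = tf.image.resize(image, (255, 255))
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_brightness(image, max_delta=0.2)
    return image

ds = tf.data.Dataset.from_tensor_slices(["a.jpg", "b.jpg"])  # hypothetical paths
ds = ds.map(tf_only_augment)

To get the OpenCV blurs into the pipeline, the straightforward Dataset-style attempt looks like this: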
import numpy as np
import cv2
import os
import tensorflow as tf
from tensorflow.python.data.experimental import AUTOTUNE
# Collect the paths of all image files under a directory
def load_files(dir):
    print("Loading files...")
    fileslist = []
    for path, _, files in os.walk(dir):
        for file in files:
            fileslist.append(os.path.join(path, file))
    return np.asarray(fileslist)

# Random Gaussian blur: pick a kernel size at random, let OpenCV derive sigma
def random_gaussian_blur(image):
    ksize_list = [3, 5, 7, 9]
    random_ksize = ksize_list[np.random.randint(0, 4)]
    image = cv2.GaussianBlur(image, ksize=(random_ksize, random_ksize), sigmaX=0, sigmaY=0)
    return image

# Random motion blur: rotate a line-shaped kernel of random length by 45 degrees
def random_motion_blur(image):
    image = np.array(image)
    angle = 45
    degree = np.random.randint(4, 12)
    M = cv2.getRotationMatrix2D((degree / 2, degree / 2), angle, 1)
    motion_blur_kernel = np.diag(np.ones(degree))
    motion_blur_kernel = cv2.warpAffine(motion_blur_kernel, M, (degree, degree))
    motion_blur_kernel = motion_blur_kernel / degree
    image = cv2.filter2D(image, -1, motion_blur_kernel)
    cv2.normalize(image, image, 0, 255, cv2.NORM_MINMAX)
    image = np.array(image, dtype=np.uint8)
    return image

# Resize to a fixed size
def resize(image):
    image = cv2.resize(image, (255, 255))
    return image

def preprocess(filename):
    # Read the source image, then produce a blurred copy as the training target
    src_image = cv2.imread(filename, cv2.IMREAD_COLOR)
    src_image = resize(src_image)
    dest_image = random_motion_blur(src_image)
    dest_image = random_gaussian_blur(dest_image)
    return [src_image, dest_image]
class MyDataset:
    def __init__(self, root_dir):
        train_dir = os.path.join(root_dir, "train")
        valid_dir = os.path.join(root_dir, "valid")
        assert os.path.exists(train_dir) and os.path.exists(valid_dir)
        self.train_filelist = load_files(train_dir)
        self.valid_filelist = load_files(valid_dir)
        assert len(self.train_filelist) > 0 and len(self.valid_filelist) > 0

    def create_dataset(self, subset="train", batch_size=16, repeat_count=None, random_transform=True):
        if subset == "train":
            filelist = self.train_filelist
        else:
            filelist = self.valid_filelist
        ds = tf.data.Dataset.from_tensor_slices(filelist)
        ds = ds.map(preprocess, num_parallel_calls=AUTOTUNE)
        ds = ds.batch(batch_size)
        ds = ds.prefetch(tf.data.experimental.AUTOTUNE)
        return ds
if __name__ == "__main__":
    dataset = MyDataset("./dataset/")
    train_ds = dataset.create_dataset(subset="train")
    for srcs, dests in train_ds:
        print(srcs.shape, dests.shape)
Running this fails: Dataset.map traces preprocess as a graph function, so filename arrives as a symbolic tf.Tensor rather than a Python string, and cv2.imread (along with the other OpenCV/NumPy calls) cannot operate on it. This is exactly where tf.py_function comes in, since it lets an eager Python function run inside the pipeline. The corrected code is as follows:
import numpy as np
import cv2
import os
import tensorflow as tf
from tensorflow.python.data.experimental import AUTOTUNE
# Collect the paths of all image files under a directory
def load_files(dir):
    print("Loading files...")
    fileslist = []
    for path, _, files in os.walk(dir):
        for file in files:
            fileslist.append(os.path.join(path, file))
    return np.asarray(fileslist)

# Random Gaussian blur: pick a kernel size at random, let OpenCV derive sigma
def random_gaussian_blur(image):
    ksize_list = [3, 5, 7, 9]
    random_ksize = ksize_list[np.random.randint(0, 4)]
    image = cv2.GaussianBlur(image, ksize=(random_ksize, random_ksize), sigmaX=0, sigmaY=0)
    return image

# Random motion blur: rotate a line-shaped kernel of random length by 45 degrees
def random_motion_blur(image):
    image = np.array(image)
    angle = 45
    degree = np.random.randint(4, 12)
    M = cv2.getRotationMatrix2D((degree / 2, degree / 2), angle, 1)
    motion_blur_kernel = np.diag(np.ones(degree))
    motion_blur_kernel = cv2.warpAffine(motion_blur_kernel, M, (degree, degree))
    motion_blur_kernel = motion_blur_kernel / degree
    image = cv2.filter2D(image, -1, motion_blur_kernel)
    cv2.normalize(image, image, 0, 255, cv2.NORM_MINMAX)
    image = np.array(image, dtype=np.uint8)
    return image

# Resize to a fixed size
def resize(image):
    image = cv2.resize(image, (255, 255))
    return image
def preprocess(filename):
    # Inside tf.py_function, filename is an eager tensor: .numpy() yields the
    # raw bytes, which we decode back into a normal Python string for OpenCV.
    filename = filename.numpy().decode("utf-8")
    src_image = cv2.imread(filename, cv2.IMREAD_COLOR)
    src_image = resize(src_image)
    dest_image = random_motion_blur(src_image)
    dest_image = random_gaussian_blur(dest_image)
    return [src_image, dest_image]

def parse_func(filename):
    # tf.py_function(func, inp, Tout): wraps an eager Python function so it can
    # run inside Dataset.map; inp is the list of inputs, Tout the output dtypes.
    [src_image, dest_image] = tf.py_function(preprocess, [filename], [tf.float32, tf.float32])
    return src_image, dest_image
class MyDataset:
    def __init__(self, root_dir):
        train_dir = os.path.join(root_dir, "train")
        valid_dir = os.path.join(root_dir, "valid")
        assert os.path.exists(train_dir) and os.path.exists(valid_dir)
        self.train_filelist = load_files(train_dir)
        self.valid_filelist = load_files(valid_dir)
        assert len(self.train_filelist) > 0 and len(self.valid_filelist) > 0

    def create_dataset(self, subset="train", batch_size=16, repeat_count=None, random_transform=True):
        if subset == "train":
            filelist = self.train_filelist
        else:
            filelist = self.valid_filelist
        ds = tf.data.Dataset.from_tensor_slices(filelist)
        ds = ds.map(parse_func, num_parallel_calls=AUTOTUNE)
        ds = ds.batch(batch_size)
        ds = ds.prefetch(tf.data.experimental.AUTOTUNE)
        return ds
if __name__ == "__main__":
    dataset = MyDataset("./dataset/")
    train_ds = dataset.create_dataset(subset="train")
    for srcs, dests in train_ds:
        print(srcs.shape, dests.shape)
This time the pipeline runs: each iteration prints the shapes of the source batch and the blurred batch, e.g. (16, 255, 255, 3) (16, 255, 255, 3).
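One practical caveat worth adding (not covered in the listing above): tensors returned by tf.py_function come back with an unknown static shape, which can confuse downstream ops or Keras layers that need it. A common workaround, sketched here under the assumption that every image is resized to 255x255x3, is to call set_shape in the mapping function:

def parse_func(filename):
    src_image, dest_image = tf.py_function(preprocess, [filename], [tf.float32, tf.float32])
    # py_function outputs lose their static shape; restore it explicitly so
    # later stages (e.g. model layers) can infer shapes at graph-build time.
    src_image.set_shape([255, 255, 3])
    dest_image.set_shape([255, 255, 3])
    return src_image, dest_image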