TensorFlow tf.data.Dataset 使用其他 Python 库进行数据增强

一般情况下,使用 tf.data.Dataset 进行数据增强时,map 中只能使用 TensorFlow 提供的 API,这存在很大的局限性。比如想借助 OpenCV 对图像随机进行高斯模糊、运动模糊,按常规的 Dataset 写法,代码如下:

import numpy as np
import cv2
import os
import tensorflow as tf
from tensorflow.python.data.experimental import AUTOTUNE

# Recursively collect file paths (paths only; no labels are loaded).
def load_files(dir, extensions=None):
    """Return the paths of all files under ``dir``, walked recursively.

    Args:
        dir: Root directory passed to ``os.walk``.
        extensions: Optional suffix filter, e.g. ``(".jpg", ".png")`` or a
            single string such as ``".jpg"``. Matching is case-insensitive.
            Useful for skipping stray non-image files that ``cv2.imread``
            cannot decode. ``None`` (the default) keeps every file.

    Returns:
        ``np.ndarray`` of path strings in sorted order, so the file order
        does not depend on the file system's directory listing order.
    """
    print("Loading files...")
    if isinstance(extensions, str):
        # A bare string would otherwise be iterated character by character.
        extensions = (extensions,)
    suffixes = None if extensions is None else tuple(ext.lower() for ext in extensions)

    fileslist = []
    for path, _, files in os.walk(dir):
        for file in files:
            if suffixes is None or file.lower().endswith(suffixes):
                fileslist.append(os.path.join(path, file))

    # os.walk yields entries in file-system dependent order; sorting makes
    # the dataset order reproducible across machines and runs.
    return np.asarray(sorted(fileslist))

# 随机高斯模糊
def random_gaussian_blur(image, ksize_list=(3, 5, 7, 9)):
    """Blur ``image`` with a Gaussian kernel whose size is drawn at random.

    Args:
        image: Image array accepted by ``cv2.GaussianBlur``.
        ksize_list: Candidate kernel sizes (OpenCV requires odd, positive
            values); one is picked uniformly at random per call.

    Returns:
        The blurred image; OpenCV keeps the input's size and type.
    """
    # Index by len(ksize_list) rather than a hard-coded bound so that
    # editing the candidate list changes what can actually be sampled.
    random_ksize = int(ksize_list[np.random.randint(0, len(ksize_list))])
    # sigmaX = sigmaY = 0: OpenCV derives sigma from the kernel size.
    return cv2.GaussianBlur(image, ksize=(random_ksize, random_ksize), sigmaX=0, sigmaY=0)

# 随机运动模糊
def random_motion_blur(image, angle=45, min_degree=4, max_degree=12):
    """Apply a linear motion blur of random length to ``image``.

    A line kernel of length ``degree`` (drawn with
    ``np.random.randint(min_degree, max_degree)``, upper bound exclusive) is
    built from a diagonal matrix rotated by ``angle`` degrees, then
    convolved with the image.

    Args:
        image: Image array; in this file it is the uint8 output of
            ``resize(cv2.imread(...))``.
        angle: Rotation applied to the diagonal line kernel, in degrees.
        min_degree: Smallest kernel length (inclusive).
        max_degree: Largest kernel length (exclusive).

    Returns:
        The blurred image as ``np.uint8``.
    """
    image = np.array(image)
    degree = np.random.randint(min_degree, max_degree)
    M = cv2.getRotationMatrix2D((degree / 2, degree / 2), angle, 1)
    motion_blur_kernel = np.diag(np.ones(degree))
    motion_blur_kernel = cv2.warpAffine(motion_blur_kernel, M, (degree, degree))
    # Rotating the diagonal clips part of the line at the kernel borders, so
    # its sum is no longer `degree`. Normalizing by the actual sum keeps the
    # overall brightness unchanged. No min-max contrast stretch is needed
    # afterwards, and a stretch would make the blurred image's intensities
    # diverge from the sharp source image.
    kernel_sum = motion_blur_kernel.sum()
    if kernel_sum > 0:
        motion_blur_kernel = motion_blur_kernel / kernel_sum
    image = cv2.filter2D(image, -1, motion_blur_kernel)
    return np.array(image, dtype=np.uint8)

# 缩放
def resize(image):
    """Return ``image`` rescaled by OpenCV to a fixed 255x255 resolution."""
    target_size = (255, 255)  # (width, height), the order cv2.resize expects
    return cv2.resize(image, target_size)

def preprocess(filename):
    """Load one image and build a ``[sharp, blurred]`` training pair.

    NOTE: ``filename`` is handed straight to ``cv2.imread``. When this
    function is passed directly to ``Dataset.map``, TensorFlow traces it
    with a symbolic ``tf.Tensor`` rather than a Python ``str``. That is why
    this first version fails.
    """
    sharp = resize(cv2.imread(filename, cv2.IMREAD_COLOR))
    # Degrade a copy: motion blur first, then Gaussian blur.
    blurred = random_gaussian_blur(random_motion_blur(sharp))
    return [sharp, blurred]


class MyDataset:
    """Build tf.data pipelines from ``<root_dir>/train`` and ``<root_dir>/valid``.

    This first version maps ``preprocess`` directly; the accompanying text
    explains why that fails.
    """

    def __init__(self, root_dir):
        """Collect the file lists of both splits under ``root_dir``.

        Raises:
            FileNotFoundError: if the ``train`` or ``valid`` directory is missing.
            ValueError: if either split directory contains no files.
        """
        train_dir = os.path.join(root_dir, "train")
        valid_dir = os.path.join(root_dir, "valid")
        # Explicit exceptions instead of `assert`: asserts are stripped when
        # Python runs with -O, which would silently skip this validation.
        for split_dir in (train_dir, valid_dir):
            if not os.path.exists(split_dir):
                raise FileNotFoundError("Dataset split directory not found: {}".format(split_dir))

        self.train_filelist = load_files(train_dir)
        self.valid_filelist = load_files(valid_dir)
        for split_dir, filelist in ((train_dir, self.train_filelist),
                                    (valid_dir, self.valid_filelist)):
            if len(filelist) == 0:
                raise ValueError("No files found in {}".format(split_dir))

    def create_dataset(self, subset="train", batch_size=16, repeat_count=None, random_transform=True):
        """Return a batched, prefetched dataset of ``(src, dest)`` image pairs.

        Args:
            subset: ``"train"`` selects the training files; any other value
                selects the validation files.
            batch_size: Number of pairs per batch.
            repeat_count: Passed to ``Dataset.repeat`` when not ``None``
                (``-1`` repeats indefinitely). ``None`` (default) iterates once.
            random_transform: Currently unused; ``preprocess`` always applies
                the random blurs. Kept so existing callers keep working.
        """
        filelist = self.train_filelist if subset == "train" else self.valid_filelist
        ds = tf.data.Dataset.from_tensor_slices(filelist)
        if repeat_count is not None:
            ds = ds.repeat(repeat_count)
        # `preprocess` calls OpenCV directly. Dataset.map traces it with
        # symbolic tensors, so this mapping is expected to fail.
        ds = ds.map(preprocess, num_parallel_calls=tf.data.experimental.AUTOTUNE)
        ds = ds.batch(batch_size)
        return ds.prefetch(tf.data.experimental.AUTOTUNE)

if __name__ == "__main__":
    # Smoke test: build the training pipeline and print each batch's shapes.
    train_ds = MyDataset("./dataset/").create_dataset(subset="train")
    for src_batch, dest_batch in train_ds:
        print(src_batch.shape, dest_batch.shape)

运行结果:

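上面的代码运行时会报错。原因是 Dataset.map 会以图模式追踪(trace)传入的函数,preprocess 收到的 filename 是一个符号张量(tf.Tensor),而不是 Python 字符串,cv2.imread 等非 TensorFlow 函数无法处理它。这时就需要借助 tf.py_function:它把普通 Python 函数包装成图中的一个操作,运行时以 eager 模式调用该函数,传入的参数是可以用 .numpy() 取值的 EagerTensor。代码如下: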

import numpy as np
import cv2
import os
import tensorflow as tf
from tensorflow.python.data.experimental import AUTOTUNE

# Recursively collect file paths (paths only; no labels are loaded).
def load_files(dir, extensions=None):
    """Return the paths of all files under ``dir``, walked recursively.

    Args:
        dir: Root directory passed to ``os.walk``.
        extensions: Optional suffix filter, e.g. ``(".jpg", ".png")`` or a
            single string such as ``".jpg"``. Matching is case-insensitive.
            Useful for skipping stray non-image files that ``cv2.imread``
            cannot decode. ``None`` (the default) keeps every file.

    Returns:
        ``np.ndarray`` of path strings in sorted order, so the file order
        does not depend on the file system's directory listing order.
    """
    print("Loading files...")
    if isinstance(extensions, str):
        # A bare string would otherwise be iterated character by character.
        extensions = (extensions,)
    suffixes = None if extensions is None else tuple(ext.lower() for ext in extensions)

    fileslist = []
    for path, _, files in os.walk(dir):
        for file in files:
            if suffixes is None or file.lower().endswith(suffixes):
                fileslist.append(os.path.join(path, file))

    # os.walk yields entries in file-system dependent order; sorting makes
    # the dataset order reproducible across machines and runs.
    return np.asarray(sorted(fileslist))

# 随机高斯模糊
def random_gaussian_blur(image, ksize_list=(3, 5, 7, 9)):
    """Blur ``image`` with a Gaussian kernel whose size is drawn at random.

    Args:
        image: Image array accepted by ``cv2.GaussianBlur``.
        ksize_list: Candidate kernel sizes (OpenCV requires odd, positive
            values); one is picked uniformly at random per call.

    Returns:
        The blurred image; OpenCV keeps the input's size and type.
    """
    # Index by len(ksize_list) rather than a hard-coded bound so that
    # editing the candidate list changes what can actually be sampled.
    random_ksize = int(ksize_list[np.random.randint(0, len(ksize_list))])
    # sigmaX = sigmaY = 0: OpenCV derives sigma from the kernel size.
    return cv2.GaussianBlur(image, ksize=(random_ksize, random_ksize), sigmaX=0, sigmaY=0)

# 随机运动模糊
def random_motion_blur(image, angle=45, min_degree=4, max_degree=12):
    """Apply a linear motion blur of random length to ``image``.

    A line kernel of length ``degree`` (drawn with
    ``np.random.randint(min_degree, max_degree)``, upper bound exclusive) is
    built from a diagonal matrix rotated by ``angle`` degrees, then
    convolved with the image.

    Args:
        image: Image array; in this file it is the uint8 output of
            ``resize(cv2.imread(...))``.
        angle: Rotation applied to the diagonal line kernel, in degrees.
        min_degree: Smallest kernel length (inclusive).
        max_degree: Largest kernel length (exclusive).

    Returns:
        The blurred image as ``np.uint8``.
    """
    image = np.array(image)
    degree = np.random.randint(min_degree, max_degree)
    M = cv2.getRotationMatrix2D((degree / 2, degree / 2), angle, 1)
    motion_blur_kernel = np.diag(np.ones(degree))
    motion_blur_kernel = cv2.warpAffine(motion_blur_kernel, M, (degree, degree))
    # Rotating the diagonal clips part of the line at the kernel borders, so
    # its sum is no longer `degree`. Normalizing by the actual sum keeps the
    # overall brightness unchanged. No min-max contrast stretch is needed
    # afterwards, and a stretch would make the blurred image's intensities
    # diverge from the sharp source image.
    kernel_sum = motion_blur_kernel.sum()
    if kernel_sum > 0:
        motion_blur_kernel = motion_blur_kernel / kernel_sum
    image = cv2.filter2D(image, -1, motion_blur_kernel)
    return np.array(image, dtype=np.uint8)

# 缩放
def resize(image):
    """Return ``image`` rescaled by OpenCV to a fixed 255x255 resolution."""
    target_size = (255, 255)  # (width, height), the order cv2.resize expects
    return cv2.resize(image, target_size)

def preprocess(filename):
    """Load one image file and return a ``[src_image, dest_image]`` pair.

    Meant to run eagerly inside ``tf.py_function`` (see ``parse_func``), so
    ``filename`` arrives as an eager string tensor. It must be converted to
    a Python ``str`` before OpenCV can use it.

    Returns:
        ``[src_image, dest_image]``: the resized image, and a copy degraded
        with random motion blur followed by random Gaussian blur. Both are
        ``uint8`` arrays.

    Raises:
        ValueError: if OpenCV cannot read or decode the file.
    """
    filename = filename.numpy().decode("utf-8")
    src_image = cv2.imread(filename, cv2.IMREAD_COLOR)
    if src_image is None:
        # cv2.imread signals failure by returning None instead of raising;
        # without this check the problem surfaces later as an opaque
        # cv2.resize assertion that does not name the offending file.
        raise ValueError("Cannot read image file: {}".format(filename))
    src_image = resize(src_image)
    dest_image = random_motion_blur(src_image)
    dest_image = random_gaussian_blur(dest_image)
    return [src_image, dest_image]

def parse_func(filename):
    """Wrap ``preprocess`` in ``tf.py_function`` so ``Dataset.map`` can use it.

    Returns:
        ``(src_image, dest_image)`` as ``tf.float32`` tensors with pixel
        values in ``[0, 255]`` and static shape ``(None, None, 3)``.
    """
    # `preprocess` returns uint8 arrays (cv2.imread output, and the uint8
    # result of random_motion_blur), so declare that dtype rather than
    # relying on an implicit conversion to float32.
    src_image, dest_image = tf.py_function(preprocess, [filename], [tf.uint8, tf.uint8])
    # py_function outputs have unknown static shape. IMREAD_COLOR always
    # yields 3 channels, so restore rank and channel count for downstream ops.
    src_image.set_shape([None, None, 3])
    dest_image.set_shape([None, None, 3])
    # Cast explicitly to keep the original float32 output dtype.
    return tf.cast(src_image, tf.float32), tf.cast(dest_image, tf.float32)


class MyDataset:
    """Build tf.data pipelines from ``<root_dir>/train`` and ``<root_dir>/valid``."""

    def __init__(self, root_dir):
        """Collect the file lists of both splits under ``root_dir``.

        Raises:
            FileNotFoundError: if the ``train`` or ``valid`` directory is missing.
            ValueError: if either split directory contains no files.
        """
        train_dir = os.path.join(root_dir, "train")
        valid_dir = os.path.join(root_dir, "valid")
        # Explicit exceptions instead of `assert`: asserts are stripped when
        # Python runs with -O, which would silently skip this validation.
        for split_dir in (train_dir, valid_dir):
            if not os.path.exists(split_dir):
                raise FileNotFoundError("Dataset split directory not found: {}".format(split_dir))

        self.train_filelist = load_files(train_dir)
        self.valid_filelist = load_files(valid_dir)
        for split_dir, filelist in ((train_dir, self.train_filelist),
                                    (valid_dir, self.valid_filelist)):
            if len(filelist) == 0:
                raise ValueError("No files found in {}".format(split_dir))

    def create_dataset(self, subset="train", batch_size=16, repeat_count=None, random_transform=True):
        """Return a batched, prefetched dataset of ``(src, dest)`` image pairs.

        Args:
            subset: ``"train"`` selects the training files; any other value
                selects the validation files.
            batch_size: Number of pairs per batch.
            repeat_count: Passed to ``Dataset.repeat`` when not ``None``
                (``-1`` repeats indefinitely). ``None`` (default) iterates once.
            random_transform: Currently unused; ``preprocess`` always applies
                the random blurs. Kept so existing callers keep working.
        """
        filelist = self.train_filelist if subset == "train" else self.valid_filelist
        ds = tf.data.Dataset.from_tensor_slices(filelist)
        if repeat_count is not None:
            ds = ds.repeat(repeat_count)
        # parse_func runs the OpenCV code through tf.py_function, so it sees
        # concrete values instead of symbolic tensors.
        ds = ds.map(parse_func, num_parallel_calls=tf.data.experimental.AUTOTUNE)
        ds = ds.batch(batch_size)
        return ds.prefetch(tf.data.experimental.AUTOTUNE)

if __name__ == "__main__":
    # Smoke test: build the training pipeline and print each batch's shapes.
    train_ds = MyDataset("./dataset/").create_dataset(subset="train")
    for src_batch, dest_batch in train_ds:
        print(src_batch.shape, dest_batch.shape)

运行结果:

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值