【TensorFlow-windows】keras接口——ImageDataGenerator裁剪

前言

Keras中有一个图像数据处理器ImageDataGenerator,能够很方便地进行数据增强,并且从文件中批量加载图片,避免数据集过大时,一下子加载进内存会崩掉。但是从官方文档发现,并没有一个比较重要的图像增强方式:随机裁剪,本博客就是记录一下如何在对ImageDataGenerator中生成的batch做图像裁剪

国际惯例,参考博客:

官方ImageDataGenerator文档

Keras 在fit_generator训练方式中加入图像random_crop

Extending Keras’ ImageDataGenerator to Support Random Cropping

how to use fit_generator with multiple image inputs

第二个博客比较全,第三个博客只介绍了分类数据的增强,如果是图像分割或者超分辨率,输出仍是一张图像,所以涉及到对imagemask进行同步增强

代码

先介绍一下数据集目录结构:

test文件夹下,分别有GTNGT两个文件夹,每个文件夹存储的都是bmp图像文件

其次需要注意,从ImageDataGenerator中取数据用的是next(generator)函数

  • 载入相关包

    from keras_preprocessing.image import ImageDataGenerator
    import matplotlib.pyplot as plt
    import numpy as np
    
  • 先使用自带的ImageDataGenerator配合flow_from_director读取数据
    创建生成器

    train_img_datagen=ImageDataGenerator()#各种预处理
    train_mask_datagen=ImageDataGenerator()#各种预处理
    

    读取文件

    seed=2 #图像会随机打乱即shuffle,但是输入和输出的打乱顺序必须一样
    batch_size=2
    target_size=(1080,1920)
    train_img_gen=train_img_datagen.flow_from_directory('./test',classes=['NGT'],
                                                        class_mode=None,
                                                        batch_size=batch_size,
                                                        target_size=target_size,
                                                        shuffle=True,
                                                        seed=seed,
                                                        interpolation='bicubic')
    train_mask_gen=train_img_datagen.flow_from_directory('./test',
                                                         classes=['GT'],
                                                         class_mode=None,
                                                         batch_size=batch_size,
                                                         target_size=target_size,
                                                         shuffle=True,
                                                         seed=seed,
                                                         interpolation='bicubic')
    

    封装打包

    train_generator=zip(train_img_gen,train_mask_gen)
    
  • 定义裁剪器,裁剪图像和对应的mask:

    def crop_generator(batch_gen,crop_size=(270,480)):
        while True:
            batch_x,batch_y=next(batch_gen)
            crops_img=np.zeros((batch_x.shape[0],crop_size[0],crop_size[1],3))
            crops_mask=np.zeros((batch_y.shape[0],crop_size[0],crop_size[1],3))
            height,width=batch_x.shape[1],batch_x.shape[2]
            for i in range(batch_x.shape[0]):
                #裁剪图像
                x=np.random.randint(0,height-crop_size[0]+1)
                y=np.random.randint(0,width-crop_size[1]+1)
                crops_img[i]=batch_x[i,x:x+crop_size[0],y:y+crop_size[1]]
                crops_mask[i]=batch_y[i,x:x+crop_size[0],y:y+crop_size[1]]
            yield (crops_img,crops_mask)
    
  • 使用裁剪器对Generator进行裁剪

    train_crops=crop_generator(train_generator)
    

可视化:

img,mask=next(train_crops)
print(img.shape)
plt.subplot(2,1,1)
plt.imshow(img[0]/255)
plt.subplot(2,1,2)
plt.imshow(mask[0]/255)

在这里插入图片描述

后记

记住要用while(True)死循环,并且yieldwhile循环内部,和for循环外部,代表每个批次

代码:
链接:https://pan.baidu.com/s/1UNZLke5kygBFHJ8iR8wV2A
提取码:e51e

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

风翼冰舟

额~~~CSDN还能打赏了

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值