代码如下:
from keras.preprocessing.image import ImageDataGenerator, array_to_img, img_to_array, load_img
from PIL import Image
import numpy as np
import os
from pathlib import Path
from multiprocessing import Pool
import random
# Graptolite categories to augment.  Each entry is used verbatim as the
# sub-folder name of that category under the training-image root.
category_list = [
    'Dicellograptus bispiralis',
    'Dicellograptus caduceus',
    'Dicellograptus divaricatus salopiensis',
    'Dicellograptus smithi',
    'Dicellograptus undatus',
    'Dicranograptus irregularis',
    'Dicranograptus sinensis',
    'Didymograptus jiangxiensis',
    'Didymograptus latus tholiformis',
    'Didymograptus miserabilis',
]
# Random augmentation pipeline shared by every worker.
# No `rescale` argument is given: generate_function divides pixel values by
# 255 itself before calling datagen.flow, so channel_shift_range below acts on
# [0, 1]-scaled intensities.
datagen = ImageDataGenerator(
    # Random mirroring along both axes.
    horizontal_flip=True,
    vertical_flip=True,
    # Random geometric (affine) transforms.
    rotation_range=30,
    width_shift_range=0.1,
    height_shift_range=0.1,
    shear_range=0.2,
    zoom_range=0.2,
    # Random per-channel intensity shift.
    channel_shift_range=0.3,
    # Pixels exposed by a transform are filled from the nearest edge pixel.
    fill_mode='nearest',
)
def _unique_save_path(save_dir, base_name, suffix='.jpg'):
    """Return ``save_dir / (base_name + suffix)``, made unique if it already exists.

    When the plain name is taken (e.g. by a file from an earlier run),
    ``_1``, ``_2``, ... is appended to ``base_name`` until an unused name is
    found, so an existing image is never overwritten.
    """
    candidate = Path(save_dir) / (base_name + suffix)
    counter = 0
    while candidate.exists():
        counter += 1
        candidate = Path(save_dir) / '{}_{}{}'.format(base_name, counter, suffix)
    return candidate


def generate_function(category, target_num=2000,
                      org_root=r'D:\set100_org\annotated_microscope_images\test1\train_images',
                      save_root=None):
    """Augment one category until its output folder holds ``target_num`` files.

    Source images are read from ``<org_root>/<category>``. Each iteration picks
    one of them at random (with replacement), applies one random transform
    from the module-level ``datagen`` and saves the result as a JPEG named
    ``<stem>_enhance<i>.jpg`` in ``<save_root>/<category>``. Files already in
    the output folder count toward ``target_num``, so re-running the function
    only generates the missing images; existing files are never overwritten.

    Args:
        category: Name of the category sub-folder (an entry of
            ``category_list``).
        target_num: Desired total number of files in the output folder.
        org_root: Folder containing one sub-folder of source images per
            category.
        save_root: Folder receiving one sub-folder of augmented images per
            category. Defaults to ``org_root`` with ``'set100_org'`` replaced
            by ``'set100'``.

    Raises:
        ValueError: If the output folder would be the source folder itself.
    """
    # Under a fork-based Pool (the default on Linux) every worker inherits the
    # parent's NumPy RNG state, and ImageDataGenerator draws its random
    # transforms from NumPy's global RNG. Re-seed from OS entropy so the
    # workers do not apply identical transform sequences.
    np.random.seed()

    if save_root is None:
        save_root = org_root.replace('set100_org', 'set100')
    org_dir = os.path.join(org_root, category)
    save_dir = os.path.join(save_root, category)
    if os.path.normcase(os.path.abspath(save_dir)) == os.path.normcase(os.path.abspath(org_dir)):
        # Writing into the source folder would mix augmented images into the
        # pool of images that get picked for augmentation.
        raise ValueError('save directory must differ from source directory: ' + org_dir)
    # The output folder may not exist yet on the first run; os.listdir below
    # would raise FileNotFoundError.
    os.makedirs(save_dir, exist_ok=True)

    # Only regular files are candidates; a sub-folder would make load_img fail.
    img_list = [name for name in os.listdir(org_dir)
                if os.path.isfile(os.path.join(org_dir, name))]
    if not img_list:
        print('No source images found in', org_dir)
        return

    remaining = target_num - len(os.listdir(save_dir))
    for i in range(remaining):
        # Pick the source image to augment uniformly at random.
        img_path = Path(org_dir) / random.choice(img_list)
        # load_img returns a PIL image.
        img = load_img(str(img_path))
        # PIL image -> array of shape (1, h, w, c): a batch of one image.
        img_array = np.expand_dims(img_to_array(img), axis=0)
        # Scale to [0, 1]; datagen has no `rescale`, so its channel shift is
        # applied on this scale.
        img_array = img_array.astype('float32') / 255
        # Draw exactly one randomly transformed copy and drop the batch axis.
        gen_array = next(datagen.flow(img_array, batch_size=1))[0]
        # array_to_img (scale=True by default) maps the array's min..max to
        # 0..255, so the saved image is also contrast-stretched.
        # NOTE(review): confirm the stretch is intended; otherwise multiply by
        # 255 and pass scale=False.
        gen_image = array_to_img(gen_array)
        save_path = _unique_save_path(save_dir, img_path.stem + '_enhance' + str(i))
        print(img_path)
        print(save_path)
        gen_image.save(save_path)
if __name__ == '__main__':
    # Augment the categories in parallel: up to 10 worker processes, one
    # category per task.
    pool = Pool(10)
    async_results = [pool.apply_async(generate_function, args=(category,))
                     for category in category_list]
    pool.close()
    pool.join()

    # apply_async stores a worker's exception inside its AsyncResult instead
    # of raising it; without calling .get() a failed category would go
    # unnoticed.
    failed = []
    for category, async_result in zip(category_list, async_results):
        try:
            async_result.get()
        except Exception as exc:  # re-raised from the worker process
            print('Augmentation failed for {}: {!r}'.format(category, exc))
            failed.append(category)
    if failed:
        raise SystemExit('Augmentation failed for: {}'.format(', '.join(failed)))
总体流程是：从 org_dir 文件夹中随机选取图像进行数据增强，增强后的图像保存在 save_dir 文件夹中；需要生成的图像数目为 target_num 减去 save_dir 中已有的图像数。文中设置 target_num=2000，即运行结束后 save_dir 中共有 2000 张图像。
具体思路是：每次从 org_dir 中随机选择一张图像，读取为 PIL 对象并转换为 NumPy 数组，然后通过 datagen 进行随机增强，再转换回 PIL 对象并保存。每一步的具体操作见代码注释。
ImageDataGenerator 的各项增强参数可参考 Keras 官方文档，按需调整。注意：代码中已将像素值除以 255，因此 channel_shift_range 等与像素值大小相关的参数是按 [0, 1] 的尺度生效的。