当数据集较小,或者模型容易过拟合时,需要对数据集进行数据增强以丰富样本数据,让一张图片变换出多种形态,从而提升网络的泛化能力。
以下提供两种实现代码,对数据进行批量随机变换,只需要修改输入和输出文件夹路径,即可批量生成。实现一中,一张图片固定生成5张增强图,另外三种翻转各以50%的概率生成,因此最多可扩增8倍;实现二中,一张图片生成9张随机变换图。(默认已安装必要的包)
原始图片
实现一:随机进行亮度增强、亮度降低,随机调整色度、对比度、锐度,并随机进行水平翻转、竖直翻转和水平+竖直组合翻转。
import os
import random
import cv2
from PIL import Image, ImageEnhance
def _save_variant(variant, target_root, subdir, filename):
    """Save `variant` as `<target_root>/<subdir>/<filename>`, creating the folder if needed."""
    out_dir = os.path.join(target_root, subdir)
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(out_dir, exist_ok=True)
    variant.save(os.path.join(out_dir, filename))


def operate(currentPath, filename, targetPath):
    """Write randomly augmented variants of one image into per-transform folders.

    Five enhancement variants are always written (sub-folders ``bright``,
    ``low_bright``, ``color``, ``cont`` and ``sharp``). Three flip variants
    (``flip_h``, ``flip_v``, ``flip_hv``) are each written with an
    independent 50% probability, so one call produces 5 to 8 files.

    Args:
        currentPath: Path of the image file to read.
        filename: File name used for every saved variant.
        targetPath: Root output folder; sub-folders are created on demand.
            A trailing path separator is optional.

    Raises:
        OSError: If the image cannot be opened or a variant cannot be saved.
    """
    # (sub-folder, enhancer class, lowest factor, highest factor); the order
    # fixes the sequence in which random factors are drawn.
    enhancements = (
        ('bright', ImageEnhance.Brightness, 0.8, 1.2),      # brighten
        ('low_bright', ImageEnhance.Brightness, 0.5, 0.9),  # darken
        ('color', ImageEnhance.Color, 0.5, 1.5),            # colour saturation
        ('cont', ImageEnhance.Contrast, 0.5, 1.5),          # contrast
        ('sharp', ImageEnhance.Sharpness, 1.5, 3.5),        # sharpness; tune as needed
    )

    # The context manager releases the file handle even if a save fails.
    with Image.open(currentPath) as image:
        for subdir, enhancer_cls, low, high in enhancements:
            factor = random.uniform(low, high)
            enhanced = enhancer_cls(image).enhance(factor)
            _save_variant(enhanced, targetPath, subdir, filename)

        # Horizontal flip, applied with 50% probability.
        if random.random() < 0.5:
            flipped_h = image.transpose(Image.FLIP_LEFT_RIGHT)
            _save_variant(flipped_h, targetPath, 'flip_h', filename)

        # Vertical flip, applied with 50% probability.
        if random.random() < 0.5:
            flipped_v = image.transpose(Image.FLIP_TOP_BOTTOM)
            _save_variant(flipped_v, targetPath, 'flip_v', filename)

        # Horizontal followed by vertical flip, applied with 50% probability.
        if random.random() < 0.5:
            flipped_hv = image.transpose(Image.FLIP_LEFT_RIGHT).transpose(
                Image.FLIP_TOP_BOTTOM
            )
            _save_variant(flipped_hv, targetPath, 'flip_hv', filename)
# Folder that is scanned recursively for input images.
imgdir = 'your_imagepath/file/'
# Root folder that receives the augmented images; it does not change per
# file, so it is set once before the loop.
outPath = 'save_path/file/'

for parent, dirnames, filenames in os.walk(imgdir):
    for filename in filenames:
        print('filename is: ' + filename)
        # Full path of the file currently being processed.
        imgPath = os.path.join(parent, filename)
        try:
            operate(imgPath, filename, outPath)
        except OSError as err:
            # Pillow reports unreadable or non-image files with an OSError
            # subclass; report and skip them so that one stray file in the
            # tree does not abort the whole batch.
            print('skipped ' + imgPath + ': ' + str(err))
实现二:随机水平、竖直翻转,随机裁剪,颜色变换(亮度、对比度、色调);随机角度旋转、压缩和中心裁剪在代码中默认被注释,可按需启用。
import os
import random
import cv2
import numpy as np
from PIL import Image
from torchvision import transforms as tfs
import matplotlib.pyplot as plt
def coms(input):
    """Apply one random flip / crop / colour-jitter pipeline to an image.

    The pipeline is rebuilt on every call because the square crop size is
    itself drawn at random from 200, 250 or 300 pixels.

    Args:
        input: Image to transform; the script passes a PIL image.

    Returns:
        The transformed image as returned by the torchvision pipeline.
    """
    crop_sizes = (200, 250, 300)  # candidate side lengths for the random crop
    side = random.choice(crop_sizes)

    pipeline = tfs.Compose([
        tfs.RandomVerticalFlip(),    # random vertical flip
        tfs.RandomHorizontalFlip(),  # random horizontal flip
        # Optional extras, disabled by default:
        #   tfs.RandomRotation(45)   - random rotation within -45..45 degrees
        #   tfs.Resize((100, 200))   - shrink to a fixed size
        #   tfs.CenterCrop(300)      - centre crop
        # pad_if_needed keeps inputs smaller than the crop size from raising.
        tfs.RandomCrop((side, side), pad_if_needed=True),
        # brightness, contrast and hue jitter
        tfs.ColorJitter(brightness=0.5, contrast=0.5, hue=0.5),
    ])
    return pipeline(input)
if __name__ == '__main__':
    # Generate a 3x3 grid of random variants of one image, saving each
    # variant to disk and showing all of them in a single figure.
    output_folder = 'save_path/file/'
    # Create the output folder up front so that saving cannot fail on a
    # missing directory.
    os.makedirs(output_folder, exist_ok=True)

    nrows = 3
    ncols = 3
    figsize = (8, 8)
    _, figs = plt.subplots(nrows, ncols, figsize=figsize)  # subplot canvas

    # The context manager closes the source file once all variants exist.
    with Image.open('your_imagepath/input.jpg') as im:
        for i in range(nrows):
            for j in range(ncols):
                # Transform the original image.
                transformed_image = coms(im)
                # Save through PIL: handing the RGB array to cv2.imwrite
                # would swap red and blue, since OpenCV expects BGR order.
                transformed_image.save(
                    os.path.join(output_folder, f"trans_img_{i}_{j}.png")
                )
                # Show the transformed image without axis ticks.
                cell = figs[i][j]
                cell.imshow(transformed_image)
                cell.axes.get_xaxis().set_visible(False)
                cell.axes.get_yaxis().set_visible(False)
    plt.show()