模型训练好后发现数据集采集的图片样本中,光照和实际应用环境不一样导致误识率很高,并且色彩也不同,比如数据集中的图片偏白,实际应用环境偏黄偏绿,这样也有一定影响,因此这里提供了一些图片增强的方法。
注意是图像增强不是数据增强,两者不一样,如何区分请移步:
几种数据增强:Mixup,Cutout,CutMix 和yolov4中的 Mosaic_RayChiu757374816的博客-CSDN博客
opencv-python 详解直方图均衡(一)《图像增强、灰度变换和直方图均衡化关系》_RayChiu757374816的博客-CSDN博客_opencv-python灰度直方图
下面给出整理好的图像增强几种方式:
import cv2
import numpy as np
import os
def getColorImg(alpha, beta, img_path, img_write_path):
    """Apply a linear brightness/contrast transform and save the result.

    Computes out = clip(alpha * img + beta, 0, 255):
    alpha is the contrast gain (>1 increases contrast),
    beta is the brightness offset added to every pixel.
    """
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread returns None for unreadable/non-image files;
        # skip instead of crashing on "alpha * None".
        return
    colored_img = np.uint8(np.clip(alpha * img + beta, 0, 255))
    cv2.imwrite(img_write_path, colored_img)
def color(alpha, beta, img_dir, img_write_dir):
    """Batch brightness/contrast enhancement over a directory.

    Runs getColorImg on every file in img_dir; the output name gets a
    'color<alpha*10>' suffix so results for several alphas can coexist
    in img_write_dir (created if missing).
    """
    if not os.path.exists(img_write_dir):
        os.makedirs(img_write_dir)
    for img_name in os.listdir(img_dir):
        img_path = os.path.join(img_dir, img_name)
        # splitext handles any extension length (.jpeg, .png, ...),
        # unlike the fragile img_name[:-4] slice.
        stem = os.path.splitext(img_name)[0]
        img_write_path = os.path.join(
            img_write_dir, stem + 'color' + str(int(alpha * 10)) + '.jpg')
        getColorImg(alpha, beta, img_path, img_write_path)
def claheMethod(img_dir, img_write_dir):
    """Apply CLAHE (adaptive histogram equalization) to every image in img_dir.

    Each image is converted to grayscale first; CLAHE on a color image
    would require equalizing the channels separately. Results are written
    to img_write_dir (created if missing) with a 'clahe' name suffix.
    """
    if not os.path.exists(img_write_dir):
        os.makedirs(img_write_dir)
    # The CLAHE operator is identical for every image — build it once.
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    for img_name in os.listdir(img_dir):
        img = cv2.imread(os.path.join(img_dir, img_name))
        if img is None:
            continue  # skip unreadable/non-image files
        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        equalized = clahe.apply(gray)
        stem = os.path.splitext(img_name)[0]
        cv2.imwrite(os.path.join(img_write_dir, stem + 'clahe' + '.jpg'),
                    equalized)
def adjust_gamma(img_dir, img_write_dir, gamma=1.0):
    """Gamma-correct every image in img_dir (gamma > 1 brightens mid-tones).

    Uses a 256-entry lookup table: out = ((i / 255) ** (1 / gamma)) * 255,
    applied with cv2.LUT. Results go to img_write_dir (created if missing)
    with an 'adjust_gamma' name suffix.
    """
    if not os.path.exists(img_write_dir):
        os.makedirs(img_write_dir)
    # The LUT depends only on gamma — build it once, not per image.
    inv_gamma = 1.0 / gamma
    table = np.array([((i / 255.0) ** inv_gamma) * 255
                      for i in range(256)]).astype("uint8")
    for img_name in os.listdir(img_dir):
        image = cv2.imread(os.path.join(img_dir, img_name))
        if image is None:
            continue  # skip unreadable/non-image files
        stem = os.path.splitext(img_name)[0]
        cv2.imwrite(os.path.join(img_write_dir, stem + 'adjust_gamma' + '.jpg'),
                    cv2.LUT(image, table))
def interMethod(img_dir, img_write_dir):
    """Upscale every image in img_dir by 2x with three interpolation methods.

    For each input, writes three outputs to img_write_dir (created if
    missing): *cubic.jpg (bicubic), *linear.jpg (bilinear) and
    *nearest.jpg (nearest-neighbor) — useful for visual comparison.
    """
    if not os.path.exists(img_write_dir):
        os.makedirs(img_write_dir)
    # Suffix/flag pairs drive one loop instead of three copies of
    # resize-then-write.
    methods = (('cubic', cv2.INTER_CUBIC),
               ('linear', cv2.INTER_LINEAR),
               ('nearest', cv2.INTER_NEAREST))
    for img_name in os.listdir(img_dir):
        img = cv2.imread(os.path.join(img_dir, img_name))
        if img is None:
            continue  # skip unreadable/non-image files
        height, width = img.shape[:2]
        new_size = (width * 2, height * 2)  # cv2.resize takes (width, height)
        stem = os.path.splitext(img_name)[0]
        for suffix, flag in methods:
            resized = cv2.resize(img, new_size, interpolation=flag)
            cv2.imwrite(os.path.join(img_write_dir, stem + suffix + '.jpg'),
                        resized)
# Source image directory
img_dir = r'E:\Users\raychiu\Desktop\22\55'
# Output directory for the enhanced images
img_write_dir = r'E:\Users\raychiu\Desktop\22\44'
# # Step 1: brightness/contrast enhancement
# alphas = [0.3, 0.5, 1.2]
# beta = 10
# for alpha in alphas:
# color(alpha, beta, img_dir, img_write_dir)
#
# # # Step 2: adaptive histogram equalization (CLAHE) to reduce the impact of color cast/imbalance
# claheMethod(img_dir, img_write_dir)
#
# # Step 3: gamma correction to reduce the impact of lighting differences
# gamma = 2.2
# adjust_gamma(img_dir, img_write_dir,gamma=gamma)
# Step 4: resize the images via interpolation
interMethod(img_dir, img_write_dir)
另外我曾经尝试将图像增强逻辑植入训练代码中,发现有些难度,因此可以通过以上代码在训练前先处理好图片,然后训练,并且使用模型检测目标之前也要做相同的图像增强处理。
检测逻辑如何加相同的处理呢?假如我只加了直方图处理(用直方图均衡化的方式实现亮度提升,更有利于边缘识别与物体识别模型的训练),那么以yolov5为例,在detect.py中找到遍历dataset的地方(约第115行):
im这个对象就是三通道的图片,我们对它做处理即可,以下是模拟加直方图:
......
for path, img, im0s, vid_cap, s in dataset:
#-----------------图片预处理开始---------------------
print("img.shape:", img.shape)#(3, 480, 640)
img = img.transpose((1, 2, 0)) #转置一下,把原来通道维度由第一维度放到最后一个维度,适配opencv (480, 640, 3)
gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) # 开始直方图处理,彩色图要拆分三个通道分别做均衡化,否则像我这里一样转为灰度图
clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8)) # 自适应均衡化,参数可选
bgrimg = clahe.apply(gray)
image = np.expand_dims(bgrimg, axis=2)#单通道的灰度图恢复为三通道图片
im = np.concatenate((image, image, image), axis=-1)
im = im.transpose((2, 0, 1))#转置回来
print("im.shape:", im.shape)#(3, 480, 640)
#-----------------图片预处理结束---------------------
t1 = time_sync()
im = torch.from_numpy(im).to(device)
im = im.half() if half else im.float() # uint8 to fp16/32
im /= 255 # 0 - 255 to 0.0 - 1.0
if len(im.shape) == 3:
......