一、背景
实际业务中需要检测破裂和破损,但是缺乏对应的数据,那么该怎么办?可以利用现成的破裂mask和无破裂的数据,再借助已经打标好的bbox合成假的破损破裂数据。
二、合成代码
import cv2
import numpy as np
import albumentations as A
import os
import xml.etree.ElementTree as ET
def add_alpha_channel(img):
    """Return a copy of ``img`` with a fully opaque alpha channel appended.

    Args:
        img: Image array of shape (H, W, 3). As a convenience a single-channel
            (H, W) image is expanded to three identical channels first, and an
            image that already has four channels is returned as a copy.

    Returns:
        A new (H, W, 4) array with the same dtype as ``img``; the fourth
        channel is filled with 255.

    Raises:
        ValueError: If ``img`` is neither 2-D nor a 3-D array with 3 or 4
            channels.
    """
    if img.ndim == 2:
        # Grayscale input: replicate the single plane into three channels.
        img = np.stack((img, img, img), axis=-1)
    elif img.ndim != 3 or img.shape[2] not in (3, 4):
        raise ValueError(
            "expected an (H, W), (H, W, 3) or (H, W, 4) image, got shape %r"
            % (img.shape,)
        )
    if img.shape[2] == 4:
        # Alpha channel already present; nothing to add.
        return img.copy()
    alpha = np.full(img.shape[:2] + (1,), 255, dtype=img.dtype)
    return np.concatenate((img, alpha), axis=2)
def merge_img(jpg_img, png_img, y1, y2, x1, x2):
    """Alpha-blend a transparent overlay onto a background image.

    The overlay is placed so that its top-left corner lands on ``(x1, y1)``
    and its bottom-right corner on ``(x2, y2)`` of the background. Any part of
    the overlay that falls outside the background is clipped; an overlay that
    lies completely outside leaves the background untouched.

    Args:
        jpg_img: Background image, (H, W, 3) or (H, W, 4). A 3-channel
            background is first converted with ``add_alpha_channel`` (a new
            array); a 4-channel background is modified in place.
        png_img: Overlay image with an alpha channel, (h, w, 4).
        y1, y2: Top and bottom rows of the target rectangle; ``y2 - y1`` must
            equal the overlay height.
        x1, x2: Left and right columns of the target rectangle; ``x2 - x1``
            must equal the overlay width.

    Returns:
        The 4-channel background with the overlay blended into its first
        three channels.

    Raises:
        ValueError: If ``png_img`` has no alpha channel.
    """
    if png_img.ndim != 3 or png_img.shape[2] < 4:
        raise ValueError("png_img must have an alpha channel (4 channels)")

    # Make sure the background has four channels.
    if jpg_img.shape[2] == 3:
        jpg_img = add_alpha_channel(jpg_img)
    bg_height, bg_width = jpg_img.shape[:2]

    # Clip the target rectangle to the background and crop the overlay by the
    # same amounts so both slices keep identical shapes.
    yy1, yy2 = 0, png_img.shape[0]
    xx1, xx2 = 0, png_img.shape[1]
    if x1 < 0:
        xx1, x1 = -x1, 0
    if y1 < 0:
        yy1, y1 = -y1, 0
    if x2 > bg_width:
        xx2 -= x2 - bg_width
        x2 = bg_width
    if y2 > bg_height:
        yy2 -= y2 - bg_height
        y2 = bg_height

    # Nothing overlaps the background: return it unchanged instead of letting
    # empty or negative slices produce a broadcasting error.
    if y1 >= y2 or x1 >= x2:
        return jpg_img

    # Overlay opacity scaled to [0, 1]; kept 3-D so it broadcasts over the
    # colour channels.
    alpha_png = png_img[yy1:yy2, xx1:xx2, 3:4] / 255.0
    alpha_jpg = 1 - alpha_png

    # Blend the three colour channels in one vectorised step.
    jpg_img[y1:y2, x1:x2, :3] = (
        alpha_jpg * jpg_img[y1:y2, x1:x2, :3]
        + alpha_png * png_img[yy1:yy2, xx1:xx2, :3]
    )
    return jpg_img
def _read_target_boxes(xml_path, keyword):
    """Return ``[xmin, ymin, xmax, ymax]`` for each object whose name contains ``keyword``."""
    boxes = []
    for obj in ET.parse(xml_path).getroot().findall('object'):
        if keyword not in obj.findtext('name', default=''):
            continue
        bndbox = obj.find('bndbox')
        if bndbox is None:
            continue
        # int(float(...)) also accepts coordinates written as "12.0".
        boxes.append([
            int(float(bndbox.findtext(tag)))
            for tag in ('xmin', 'ymin', 'xmax', 'ymax')
        ])
    return boxes


def main(image_path="./dataHxq8/images2/",
         xmlpath="./dataHxq8/Annotations2/",
         mask_path="./裂痕/",
         out_dir="./gjps/",
         iterations=2000,
         keyword="gj"):
    """Synthesize crack/damage samples by pasting crack masks into labelled boxes.

    Each iteration picks one random background image and one random crack
    mask, reads the background's annotation file, and for every box whose
    class name contains ``keyword`` blends a randomly sized and positioned
    copy of the mask inside that box. One output image is written per box.

    Args:
        image_path: Directory with the original background images.
        xmlpath: Directory with the annotation XML files; each file is
            expected to share its stem with the image it describes.
        mask_path: Directory with crack mask images that carry an alpha
            channel (4 channels).
        out_dir: Directory the synthesized images are written to; created if
            it does not exist.
        iterations: Number of (background, mask) pairs to draw.
        keyword: Substring that selects which annotated classes receive an
            overlay.
    """
    import random

    image_list = os.listdir(image_path)
    mask_list = os.listdir(mask_path)
    if not image_list or not mask_list:
        print('no background images or no masks found, nothing to do')
        return
    # cv2.imwrite only returns False when the target directory is missing,
    # so make sure it exists up front.
    os.makedirs(out_dir, exist_ok=True)

    for i in range(iterations):
        image_name = random.choice(image_list)
        mask_name = random.choice(mask_list)
        print(image_name, mask_name)

        # splitext keeps file names that contain extra dots intact.
        xml_path = os.path.join(xmlpath, os.path.splitext(image_name)[0] + '.xml')
        if not os.path.isfile(xml_path):
            print('skip: annotation not found', xml_path)
            continue

        img_jpg = cv2.imread(os.path.join(image_path, image_name), cv2.IMREAD_UNCHANGED)
        img_png = cv2.imread(os.path.join(mask_path, mask_name), cv2.IMREAD_UNCHANGED)
        if img_jpg is None or img_png is None:
            print('skip: unreadable image', image_name, mask_name)
            continue
        if img_jpg.ndim != 3:
            print('skip: background is not a colour image', image_name)
            continue
        # The blend needs an alpha channel on the mask. For hand-cut masks
        # without one, an alternative is to load the mask as BGR, optionally
        # augment it (e.g. with albumentations: flips, resize, distortion,
        # brightness/contrast), threshold its grayscale version (e.g. at 90)
        # to obtain an alpha plane, and merge that plane as a fourth channel.
        if img_png.ndim != 3 or img_png.shape[2] != 4:
            print('skip: mask has no alpha channel', mask_name)
            continue

        try:
            boxes = _read_target_boxes(xml_path, keyword)
        except ET.ParseError:
            print('skip: malformed annotation', xml_path)
            continue

        for j, (xmin, ymin, xmax, ymax) in enumerate(boxes):
            box_width = xmax - xmin
            box_height = ymax - ymin
            # Overlay covers a random fraction in (0.5, 0.9] of each box side.
            width_ratio = 0.9 - random.random() * 0.4
            height_ratio = 0.9 - random.random() * 0.4
            png_width = int(box_width * width_ratio)
            png_height = int(box_height * height_ratio)
            if png_width < 1 or png_height < 1:
                # Box too small to hold an overlay; cv2.resize rejects size 0.
                continue
            # Always resize from the original mask so that resampling error
            # does not accumulate from one box to the next.
            overlay = cv2.resize(img_png, (png_width, png_height))
            # Random offset, bounded so the overlay stays inside the box.
            x1 = xmin + int(random.random() * (1 - width_ratio) * 0.88 * box_width)
            y1 = ymin + int(random.random() * (1 - height_ratio) * 0.88 * box_height)
            x2 = x1 + png_width
            y2 = y1 + png_height

            res_img = merge_img(img_jpg, overlay, y1, y2, x1, x2)

            out_file = os.path.join(out_dir, '_' + str(i) + '_' + str(j) + '_' + image_name)
            if not cv2.imwrite(out_file, res_img):
                print('failed to write', out_file)


if __name__ == '__main__':
    main()