在实际项目中,负样本有时并不容易收集。(这里的负样本指图像中某一特征区域的形态与正样本完全不同,例如隔离开关的不同状态。)此时需要用极少量的真实负样本去生成大量负样本:先将真实负样本中的特征区域截取出来并做处理,再把它融合(PS)到正样本图片的对应特征区域上,替换原有特征,从而生成负样本数据。
以下为代码解析:
# -*-coding = utf-8 -*-
# @Time: 2021/12/6 18:38
# @Author: xiaoya
# @File: peizhi.py
# @Software: PyCharm
import cv2
import random
import numpy as np
import os
import time
import albumentations as A
def rand(a=0, b=1):
    """Return a random float drawn uniformly from the half-open range [a, b).

    Args:
        a: Lower bound of the range (inclusive).
        b: Upper bound of the range (exclusive).

    Returns:
        ``a`` plus a uniform sample from [0, 1) scaled by ``b - a``; when
        ``a == b`` the result is exactly ``a``.
    """
    return np.random.rand() * (b - a) + a
# Image augmentation pipeline; enable or disable individual transforms as needed.
# Only the un-commented entries are active: a horizontal flip (p=0.5),
# multiplicative noise (p=0.5) and colour jitter (p=0.3). The commented-out
# lines are optional alternatives left in place for experimentation.
transform = A.Compose([
#A.RandomCrop(width=256, height=256),
A.HorizontalFlip(p=0.5),
#A.RandomBrightnessContrast(brightness_limit=(0.1,0.5), contrast_limit=(0.1,0.5),p=0.2),
#A.JpegCompression(quality_lower=19, quality_upper=20, p=1),
#A.Blur(blur_limit=(15, 15), p=1),
#A.RGBShift(r_shift_limit=(10,50), g_shift_limit=(10,50), b_shift_limit=(10,50), p=0.3),
A.MultiplicativeNoise(multiplier=(0.3, 1.1), p=0.5),
A.ColorJitter(brightness=(0.3,0.7),contrast=(0.3,0.7),saturation=(0.3,0.7),hue=(-0.5,0.5),p=0.3),
#A.Posterize(num_bits=2, p=1)
#A.MotionBlur(blur_limit=(5,30), p=0.3),
#A.RandomGamma(p=1),
#A.Downscale(scale_min=0.25, scale_max=0.75, p=1),
#A.RandomShadow(shadow_roi=(0, 0, 1, 1), p=0.2),
])
# Read the bounding boxes of one class from a YOLO-format label txt file.
def box_from_txt(txt, img_h, img_w, target_label=1):
    """Parse a YOLO label file and return pixel boxes of one class.

    Each row of the file is expected to hold
    ``class x_center y_center width height`` with the four geometry values
    normalised to the image size.

    Args:
        txt: Path of the label file.
        img_h: Image height in pixels.
        img_w: Image width in pixels.
        target_label: Class id to keep; rows of any other class are ignored.
            Defaults to 1.

    Returns:
        A list of ``[label, xmin, ymin, xmax, ymax]`` entries in pixel
        coordinates, in file order.
    """
    boxes = []
    with open(txt, 'r') as f:
        for line in f:
            # split() without arguments tolerates tabs and repeated spaces.
            fields = line.split()
            if len(fields) < 5:
                # Blank or truncated row: nothing usable here.
                continue
            try:
                label = int(float(fields[0]))
            except ValueError:
                # Non-numeric class field (e.g. a header row): skip it.
                continue
            # Compare the parsed class id, not the first character, so that
            # ids such as 10..19 are not mistaken for class 1.
            if label != target_label:
                continue
            center_x = round(float(fields[1]) * img_w)
            center_y = round(float(fields[2]) * img_h)
            bbox_width = round(float(fields[3]) * img_w)
            bbox_height = round(float(fields[4]) * img_h)
            xmin = int(center_x - bbox_width / 2)
            ymin = int(center_y - bbox_height / 2)
            xmax = int(center_x + bbox_width / 2)
            ymax = int(center_y + bbox_height / 2)
            boxes.append([label, xmin, ymin, xmax, ymax])
    return boxes
# Paste centres for the differently scaled negative patches. Four layouts are
# provided here; add or change formulas as needed.
def calculate_center(roi, roi_height, roi_width, scale, scale1):
    """Return the four candidate paste centres for a scaled patch.

    Args:
        roi: Box as ``[label, xmin, ymin, xmax, ymax]``; only ``roi[1]``
            (xmin) and ``roi[2]`` (ymin) are read.
        roi_height: Height of the box in pixels.
        roi_width: Width of the box in pixels.
        scale: Height scale factor of the patch.
        scale1: Width scale factor of the patch.

    Returns:
        A list ``[center1, center2, center3, center4]`` of ``(x, y)`` integer
        tuples:
          * center1 - x centred on the unscaled width, y on the scaled height
            measured down from ymin;
          * center2 - x centred on the scaled width, y shifted up by the
            extra height;
          * center3 - x centred on the scaled width, y measured down from
            ymin;
          * center4 - x centred on the unscaled width, y shifted up by the
            extra height.
    """
    x_min, y_min = roi[1], roi[2]
    # Extra height the patch gains over the box, and half of the scaled height.
    extra_height = roi_height * (scale - 1)
    half_scaled_height = (roi_height + extra_height) / 2

    x_plain = int(x_min + roi_width / 2)
    x_scaled = int(x_min + roi_width * scale1 / 2)
    y_shifted_up = int(y_min - extra_height + half_scaled_height)
    y_from_top = int(y_min + half_scaled_height)

    center1 = (x_plain, int(y_min + roi_height * scale / 2))
    center2 = (x_scaled, y_shifted_up)
    center3 = (x_scaled, y_from_top)
    center4 = (x_plain, y_shifted_up)
    return [center1, center2, center3, center4]
imgfile = r'D:/images/'  # directory holding the source images
txtfile = r'D:/txt/'  # directory holding the label (annotation) txt files
savepath = r'D:/out/'  # directory receiving the newly generated negative samples
# os.listdir returns entries in arbitrary order, while images and labels are
# paired by list index; sort both listings by file stem so that position k in
# one list corresponds to the same-named file in the other.
img_list = sorted(os.listdir(imgfile), key=lambda name: os.path.splitext(name)[0])
txt_list = sorted(os.listdir(txtfile), key=lambda name: os.path.splitext(name)[0])
# mapfiles = open(savepath + 'out_mapfiles.txt', 'w', encoding='UTF-8')
# Each style entry is a list of five values:
#   [0], [1]  index values, described as an index range into the image directory;
#   [2], [3]  adjustment factors for the new label (read by run() as the
#             height scale and the width scale of the pasted patch); tune freely;
#   [4]       directory holding the cropped negative patches of one kind/size.
style1 = [0, 200, 1.0, 1.0, 'D:/win/1/']
style2 = [200, 400, 1.0, 1.2, 'D:/win/2/']
style3 = [400, 548, 1.0, 1.5, 'D:/win/3/']
style4 = [548, 648, 1.0, 1.3, 'D:/win/4/']
style = [style1, style2, style3, style4]
def run(style):
    """Generate synthetic negative samples for every style configuration.

    For each style entry, 200 images are drawn at random from ``img_list``.
    Up to 12 labelled boxes of every drawn image are visited; each box is
    either kept (its original label is written back) or, on a coin flip,
    covered by a randomly chosen, augmented negative patch that is blended in
    with ``cv2.seamlessClone`` and labelled as class 0. Per sample, one image
    (``out_NNNN.jpg``), one label file (``out_NNNN.txt``) and one line in
    ``out_mapfiles.txt`` (source name -> output label file) are written to
    ``savepath``.

    Args:
        style: Sequence of style entries ``[start, end, scale, scale1,
            patch_dir]``. ``scale`` / ``scale1`` stretch the patch height /
            width relative to the box and ``patch_dir`` holds the cropped
            negative patches. The position of an entry selects which of the
            four centres returned by ``calculate_center`` is used, so at most
            four entries are supported.

    Raises:
        ValueError: If no source image or no patch is available, if an image
            cannot be decoded, or if the image and label file at the same
            list index do not share a name.
    """
    samples_per_style = 200
    max_boxes = 12  # at most this many boxes per image are processed

    if not img_list:
        raise ValueError("no images found in " + imgfile)

    # NOTE(review): entries [0] and [1] of each style are documented as an
    # image index range, but sampling below draws from the whole directory,
    # as the original loop did - confirm which behaviour is intended.
    with open(savepath + 'out_mapfiles.txt', 'w', encoding='UTF-8') as mapfiles:
        # Use a separate loop variable: rebinding the ``style`` parameter
        # would break indexing from the second iteration on.
        for i, cfg in enumerate(style):
            scale, scale1, winfile = cfg[2], cfg[3], cfg[4]
            # The patch directory does not change inside the loops below.
            win_list = os.listdir(winfile)
            if not win_list:
                raise ValueError("no negative patches found in " + winfile)

            for idx in range(samples_per_style):
                # Running index across styles so that later styles do not
                # overwrite the files produced by earlier ones.
                out_idx = i * samples_per_style + idx
                # randint's upper bound is exclusive, so len(img_list) covers
                # every image including the last one.
                index_img = np.random.randint(0, len(img_list))
                imgpath = imgfile + img_list[index_img]
                annpath = txtfile + txt_list[index_img]
                name_img = os.path.splitext(os.path.basename(imgpath))[0]
                name_ann = os.path.splitext(os.path.basename(annpath))[0]
                # Sanity-check that image and label belong together.
                if name_img != name_ann:
                    raise ValueError("txt_file is not exist for image")
                print(idx + 1, annpath)

                # np.fromfile + imdecode instead of imread so that paths
                # containing non-ASCII characters can be read.
                img = cv2.imdecode(np.fromfile(imgpath, dtype=np.uint8), -1)
                if img is None:
                    raise ValueError("cannot decode image " + imgpath)
                img_h, img_w = img.shape[0], img.shape[1]
                boxes = box_from_txt(annpath, img_h, img_w)

                out_name = 'out_%04d' % out_idx
                mapfiles.write(name_ann + ' ' + out_name + '.txt\n')
                with open(savepath + out_name + '.txt', 'w', encoding='UTF-8') as out_file:
                    # Slicing copes with label files holding fewer boxes
                    # than max_boxes.
                    for roi in boxes[:max_boxes]:
                        if np.random.randint(0, 2) == 0:
                            # Keep the box: write its label back in
                            # normalised YOLO form.
                            label = [roi[0],
                                     float(((roi[1] + roi[3]) / 2) / img_w),
                                     float(((roi[2] + roi[4]) / 2) / img_h),
                                     float((roi[3] - roi[1]) / img_w),
                                     float((roi[4] - roi[2]) / img_h)]
                        else:
                            roi_height = int(roi[4] - roi[2])
                            roi_width = int(roi[3] - roi[1])
                            index_win = np.random.randint(0, len(win_list))
                            winpath = winfile + win_list[index_win]
                            window = cv2.imdecode(np.fromfile(winpath, dtype=np.uint8), -1)
                            if window is None:
                                raise ValueError("cannot decode image " + winpath)
                            # The augmentation pipeline is run on RGB data;
                            # convert back to BGR afterwards.
                            window = cv2.cvtColor(window, cv2.COLOR_BGR2RGB)
                            transformed_image = transform(image=window)["image"]
                            transformed_image = cv2.cvtColor(transformed_image, cv2.COLOR_RGB2BGR)
                            window_resize = cv2.resize(
                                transformed_image,
                                (int(roi_width * scale1), int(roi_height * scale)))
                            center = calculate_center(roi, roi_height, roi_width, scale, scale1)
                            # Pasted patches are labelled as class 0.
                            label = [0,
                                     float(center[i][0] / img_w),
                                     float(center[i][1] / img_h),
                                     float(roi_width * scale1 / img_w),
                                     float(roi_height * scale / img_h)]
                            src_mask = 255 * np.ones(window_resize.shape, window_resize.dtype)
                            # NOTE(review): seamlessClone expects 8-bit
                            # 3-channel images and a patch lying inside the
                            # destination - confirm the data guarantees it.
                            img = cv2.seamlessClone(window_resize, img, src_mask,
                                                    center[i], cv2.NORMAL_CLONE)
                        out_file.write(" ".join([str(a) for a in label]) + '\n')
                cv2.imwrite(savepath + out_name + '.jpg', img)
# Script entry point: generate the negative samples for all configured styles.
if __name__ == "__main__":
    run(style)
欢迎各位小伙伴交流哦!
有不懂的可留言~