记录一下自己使用 skimage 做数据增强时所遇到的一些情况。
实现的功能:对图片实现调明暗,模糊,噪音,裁剪,缩放,翻转及其组合,并生成对应使用labelimg生成的标注xml文件。
遇到的问题:
问题1:生成高斯噪音时,图像上有很明显的杂色聚集。
原因:像素值在加高斯噪音 img+random.gauss(mean,sigma) 时,部分像素会小于0或大于255,而 uint8 无法表示这些值,溢出后回绕成错误的颜色。
解决办法:先把图像转成 int16,再做截断 img[img>255]=255, img[img<0]=0,最后转回 uint8。
问题2:有部分图像保存失败。
原因:在调用函数 filters.gaussian() 时,数据被转成了 float,skimage 中 float 图像的取值范围限制在 [-1,1],部分图片超出范围导致保存不成功。
解决办法:数据截断 img[img>1]=1,img[img<-1] = -1
问题3:保存的图片成纯白或纯黑
原因:在组合增强方式的时候,不同函数返回的 dtype 不一致(有的是 uint8,有的是 float64),对它们进行了错误的截断。
解决办法:打印出每个返回图像数据的dtype,对应做截断。
"""
数据增强实现功能:定义图片增强类 ImgAugmentation,实现图片数据增强;对应定义 XML 增强类 XmlAugmentation,生成与增强图片对应的 labelimg 标注 XML 文件。使用时设置
auglist 参数选择要进行的增强内容,并设置原图、保存图片、原 XML、保存 XML 的路径。
"""
import skimage
from skimage import io, transform, filters
from skimage import color, data_dir, data
from skimage import exposure, img_as_float,img_as_ubyte,img_as_int,dtype_limits
import random
import cv2
import numpy as np
import os
import xml.dom.minidom
class ImgAugmentation:
    """Write augmented variants of every ``*.jpg`` image found in a directory.

    Augmentation families are selected by integer index (see ``gen``):
    1 brightness, 2 noise, 3 blur, 4 combinations of 1-3, 5 horizontal flip,
    6 vertical flip, 7 rescale, 8 crop of the top/left 10%.  Families 5-8
    also emit the four combinations of family 4 applied to the transformed
    image.  Every output is saved as ``<prefix><original file name>``.
    """

    # Fraction of the pixels overwritten by salt-and-pepper noise.
    SALT_PEPPER_RATIO = 0.005
    # Factor applied to the image height and width by ``rescale_func``.
    RESCALE_FACTOR = 0.8
    # Augmentation index -> file-name prefixes, one per generated image.
    PREFIXES = {
        1: ["bnh_", "bnl_"],
        2: ["ns_", "gns_"],
        3: ["fil1_", "fil2_"],
        4: ["add12_", "add13_", "add23_", "add123_"],
        5: ["fliph_", "fliph12_", "fliph13_", "fliph23_", "fliph123_"],
        6: ["flipv_", "flipv12_", "flipv13_", "flipv23_", "flipv123_"],
        7: ["rsl_", "rsl12_", "rsl13_", "rsl23_", "rsl123_"],
        8: ["crop_", "crop12_", "crop13_", "crop23_", "crop123_"],
    }

    def __init__(self, imgpath, img_save_path):
        """Remember the source image directory and the output directory."""
        self.imgpath = imgpath
        self.img_save_path = img_save_path

    def gen(self, auglist):
        """Run every augmentation family whose index appears in *auglist*.

        Args:
            auglist: Container of integers in 1..8; other values are ignored.
        """
        for idx, prefixes in self.PREFIXES.items():
            if idx in auglist:
                self.img_gen(idx, prefixes)

    def img_gen(self, load_func_idx, str_lists):
        """Apply one augmentation family to all images and save the results.

        Args:
            load_func_idx: Family index in 1..8.
            str_lists: File-name prefixes, one per image the family returns.

        Raises:
            ValueError: If *load_func_idx* is not a known family index.
        """
        load_funcs = {
            1: self.brightness_func,
            2: self.noise_func,
            3: self.filter_func,
            4: self.add_func,
            5: self.fliph_func,
            6: self.flipv_func,
            7: self.rescale_func,
            8: self.crop_func,
        }
        if load_func_idx not in load_funcs:
            raise ValueError(
                "unknown augmentation index: %r" % (load_func_idx,))
        # os.path.join keeps the pattern valid on every platform.
        pattern = os.path.join(self.imgpath, "*.jpg")
        coll = io.ImageCollection(pattern, load_func=load_funcs[load_func_idx])
        for i in range(len(coll)):
            # Each entry is (source path, image 1, image 2, ...).
            result = coll[i]
            imgname = os.path.split(result[0])[1]
            for prefix, image in zip(str_lists, result[1:]):
                target = os.path.join(self.img_save_path, prefix + imgname)
                # Best effort: one image that cannot be saved must not stop
                # the rest of the batch.
                try:
                    io.imsave(target, image)
                    print(target, "is save")
                except Exception as e:
                    print(target + " is error to save")
                    print(e)

    def brightness_func(self, imgfile):
        """Return ``(imgfile, gamma 0.5 image, gamma 1.5 image)``."""
        img = io.imread(imgfile)
        return (imgfile,
                exposure.adjust_gamma(img, 0.5),
                exposure.adjust_gamma(img, 1.5))

    def noise_func(self, imgfile):
        """Return ``(imgfile, salt-and-pepper image, Gaussian-noise image)``."""
        img = io.imread(imgfile)
        return imgfile, self._salt_pepper(img), self._gaussian_noise(img, 10)

    def filter_func(self, imgfile):
        """Return ``(imgfile, blur sigma=1, blur sigma=1.5)`` as uint8 images."""
        img = io.imread(imgfile)
        return imgfile, self._blur(img, 1), self._blur(img, 1.5)

    def add_func(self, imgfile):
        """Return *imgfile* plus the four random combined augmentations."""
        img = io.imread(imgfile)
        return (imgfile,) + tuple(self.img_add_b_n_f(img))

    def fliph_func(self, imgfile):
        """Return *imgfile*, the horizontally flipped image and its combos."""
        # Two-index slicing also works for single-channel images.
        img_fliph = io.imread(imgfile)[:, ::-1]
        return (imgfile, img_fliph) + tuple(self.img_add_b_n_f(img_fliph))

    def flipv_func(self, imgfile):
        """Return *imgfile*, the vertically flipped image and its combos."""
        img_flipv = io.imread(imgfile)[::-1]
        return (imgfile, img_flipv) + tuple(self.img_add_b_n_f(img_flipv))

    def rescale_func(self, imgpath):
        """Return *imgpath*, the image shrunk by RESCALE_FACTOR and its combos."""
        img = io.imread(imgpath)
        height, width = img.shape[:2]
        # resize() with a (rows, cols) shape keeps the channel axis untouched,
        # which rescale() only does when told the image is multichannel.  The
        # rounding matches round(size * 0.8) used for the XML annotations.
        new_shape = (round(height * self.RESCALE_FACTOR),
                     round(width * self.RESCALE_FACTOR))
        img_rsl = img_as_ubyte(transform.resize(img, new_shape))
        return (imgpath, img_rsl) + tuple(self.img_add_b_n_f(img_rsl))

    def crop_func(self, imgpath):
        """Return *imgpath*, the image minus its top/left 10% and its combos."""
        img = io.imread(imgpath)
        crop_h = img.shape[0] // 10
        crop_w = img.shape[1] // 10
        img_crop = img[crop_h:, crop_w:]
        return (imgpath, img_crop) + tuple(self.img_add_b_n_f(img_crop))

    def img_add_b_n_f(self, img):
        """Build the four random combinations of brightness, noise and blur.

        Args:
            img: uint8 image array.

        Returns:
            Tuple ``(brightness+noise, brightness+blur, noise+blur,
            brightness+noise+blur)``; every image is uint8.
        """
        img_add12 = self._random_noise(
            self._random_gamma(img), random.randint(5, 15))
        img_add13 = self._random_blur(self._random_gamma(img))
        img_add23 = self._random_noise(self._random_blur(img), 10)
        img_add123 = self._random_noise(
            self._random_blur(self._random_gamma(img)), 10)
        return img_add12, img_add13, img_add23, img_add123

    def _random_gamma(self, img):
        """Gamma-adjust *img* with a random gamma in [0.5, 1.5]."""
        return exposure.adjust_gamma(img, random.uniform(0.5, 1.5))

    def _random_blur(self, img):
        """Blur *img* with a random sigma in [1, 1.5]; returns uint8."""
        return self._blur(img, random.uniform(1, 1.5))

    def _blur(self, img, sigma):
        """Gaussian-blur *img* and convert the float result back to uint8."""
        # NOTE(review): depending on the installed scikit-image release,
        # gaussian() may need channel_axis=-1 (or multichannel=True) so that
        # colour channels are not blurred into each other - confirm.
        blurred = filters.gaussian(img, sigma)
        # gaussian() returns floats; clip before the uint8 conversion so a
        # slightly out-of-range value cannot make the conversion fail.
        return img_as_ubyte(np.clip(blurred, 0, 1))

    def _random_noise(self, img, sigma):
        """Add Gaussian noise or salt-and-pepper noise, chosen by coin flip."""
        if random.randint(0, 1) == 0:
            return self._gaussian_noise(img, sigma)
        return self._salt_pepper(img)

    def _gaussian_noise(self, img, sigma):
        """Return a uint8 copy of *img* with zero-mean Gaussian noise added."""
        # Work in floats and clip: adding directly to uint8 would wrap
        # around and produce clusters of wrong colours.
        noisy = img.astype(np.float64) + np.random.normal(0, sigma, img.shape)
        return np.clip(noisy, 0, 255).astype(np.uint8)

    def _salt_pepper(self, img):
        """Return a copy of *img* with random pixels set to 255 or 0."""
        noisy = img.copy()
        height, width = noisy.shape[:2]
        num = int(width * height * self.SALT_PEPPER_RATIO)
        # randint's upper bound is exclusive, so pass the full size: every
        # row and column can be hit and 1-pixel-wide images do not raise.
        ys = np.random.randint(0, height, num)
        xs = np.random.randint(0, width, num)
        salt = np.random.randint(0, 2, num).astype(bool)
        noisy[ys[salt], xs[salt]] = 255
        noisy[ys[~salt], xs[~salt]] = 0
        return noisy
class XmlAugmentation:
    """Write labelImg XML annotations for augmented copies of images.

    For every ``.xml`` file in ``xmlpath`` and every file-name prefix of the
    selected augmentation families, a copy of the annotation is saved to
    ``xml_save_path`` with the ``<filename>``, the image size and the
    bounding boxes adjusted to the corresponding geometric transform.
    """

    # Maximum number of pixels a box edge is randomly pushed outwards.
    RANDOM_SIZE = 3
    # Scale factor applied to sizes and coordinates by the rescale family.
    RESCALE_FACTOR = 0.8
    _BOX_TAGS = ("xmin", "ymin", "xmax", "ymax")

    def __init__(self, xmlpath, xml_save_path):
        """Remember the source annotation directory and the output directory."""
        self.xmlpath = xmlpath
        self.xml_save_path = xml_save_path

    def gen(self, auglist):
        """Write the annotations of every family whose index is in *auglist*.

        Args:
            auglist: Container of integers in 1..8; other values are ignored.
        """
        steps = (
            (1, self.xml_same_size, ["bnh_", "bnl_"]),
            (2, self.xml_same_size, ["ns_", "gns_"]),
            (3, self.xml_same_size, ["fil1_", "fil2_"]),
            (4, self.xml_same_size,
             ["add12_", "add13_", "add23_", "add123_"]),
            (5, self.xml_fliph,
             ["fliph_", "fliph12_", "fliph13_", "fliph23_", "fliph123_"]),
            (6, self.xml_flipv,
             ["flipv_", "flipv12_", "flipv13_", "flipv23_", "flipv123_"]),
            (7, self.xml_rescale,
             ["rsl_", "rsl12_", "rsl13_", "rsl23_", "rsl123_"]),
            (8, self.xml_crop,
             ["crop_", "crop12_", "crop13_", "crop23_", "crop123_"]),
        )
        for idx, handler, prefixes in steps:
            if idx in auglist:
                handler(prefixes)

    def xml_same_size(self, str_names):
        """Write annotations for augmentations that keep the image geometry."""
        def edit(root):
            width, height = self._image_size(root)
            self.random_bndbox(
                root.getElementsByTagName("object"), width, height)
        self._write_variants(str_names, edit)

    def xml_fliph(self, str_names):
        """Write annotations for horizontally flipped images."""
        def edit(root):
            width, height = self._image_size(root)
            self.fliph_random_bndbox(
                root.getElementsByTagName("object"), width, height)
        self._write_variants(str_names, edit)

    def xml_flipv(self, str_names):
        """Write annotations for vertically flipped images."""
        def edit(root):
            width, height = self._image_size(root)
            self.flipv_random_bndbox(
                root.getElementsByTagName("object"), width, height)
        self._write_variants(str_names, edit)

    def xml_rescale(self, str_names):
        """Write annotations for images rescaled by RESCALE_FACTOR."""
        def edit(root):
            width_node, height_node = self._size_nodes(root)
            width = round(int(width_node.data) * self.RESCALE_FACTOR)
            height = round(int(height_node.data) * self.RESCALE_FACTOR)
            width_node.data = str(width)
            height_node.data = str(height)
            self.rescale_random_bndbox(
                root.getElementsByTagName("object"), width, height)
        self._write_variants(str_names, edit)

    def xml_crop(self, str_names):
        """Write annotations for images whose top/left 10% was cropped off."""
        def edit(root):
            width_node, height_node = self._size_nodes(root)
            wid = int(width_node.data)
            hei = int(height_node.data)
            crop_w = wid // 10
            crop_h = hei // 10
            width_node.data = str(wid - crop_w)
            height_node.data = str(hei - crop_h)
            self.crop_random_bndbox(
                root.getElementsByTagName("object"), crop_w, crop_h, root)
        self._write_variants(str_names, edit)

    def fliph_random_bndbox(self, objects, width_value, height_value):
        """Mirror every box horizontally and grow it by a random margin."""
        for obj in objects:
            xmin, ymin, xmax, ymax = self._box_nodes(obj)
            # Read both x values first: after mirroring, the old xmax
            # defines the new xmin and vice versa.
            old_xmin, old_xmax = int(xmin.data), int(xmax.data)
            xmin.data = self._jitter_min(width_value - old_xmax)
            xmax.data = self._jitter_max(width_value - old_xmin, width_value)
            ymin.data = self._jitter_min(int(ymin.data))
            ymax.data = self._jitter_max(int(ymax.data), height_value)

    def flipv_random_bndbox(self, objects, width_value, height_value):
        """Mirror every box vertically and grow it by a random margin."""
        for obj in objects:
            xmin, ymin, xmax, ymax = self._box_nodes(obj)
            old_ymin, old_ymax = int(ymin.data), int(ymax.data)
            xmin.data = self._jitter_min(int(xmin.data))
            xmax.data = self._jitter_max(int(xmax.data), width_value)
            ymin.data = self._jitter_min(height_value - old_ymax)
            ymax.data = self._jitter_max(height_value - old_ymin, height_value)

    def rescale_random_bndbox(self, objects, width_value, height_value):
        """Scale every box by RESCALE_FACTOR and grow it by a random margin."""
        factor = self.RESCALE_FACTOR
        for obj in objects:
            xmin, ymin, xmax, ymax = self._box_nodes(obj)
            xmin.data = self._jitter_min(int(int(xmin.data) * factor))
            ymin.data = self._jitter_min(int(int(ymin.data) * factor))
            xmax.data = self._jitter_max(
                int(int(xmax.data) * factor), width_value)
            ymax.data = self._jitter_max(
                int(int(ymax.data) * factor), height_value)

    def crop_random_bndbox(self, objects, crop_w, crop_h, root):
        """Shift boxes by the cropped margin; drop boxes that were cut.

        Args:
            objects: ``<object>`` elements of the annotation.
            crop_w: Number of pixel columns removed on the left.
            crop_h: Number of pixel rows removed at the top.
            root: Document element; used as the parent of an ``<object>``
                that reports no parent of its own.
        """
        for obj in list(objects):
            xmin, ymin, xmax, ymax = self._box_nodes(obj)
            new_xmin = int(xmin.data) - crop_w
            new_ymin = int(ymin.data) - crop_h
            if new_xmin < 0 or new_ymin < 0:
                # The box starts inside the removed strip, so the object is
                # only partly visible: remove its annotation.
                parent = obj.parentNode if obj.parentNode is not None else root
                parent.removeChild(obj)
                continue
            xmin.data = str(new_xmin)
            ymin.data = str(new_ymin)
            xmax.data = str(int(xmax.data) - crop_w)
            ymax.data = str(int(ymax.data) - crop_h)

    def random_bndbox(self, xml_objects, width, height):
        """Grow every box by a random margin, clamped to the image."""
        for obj in xml_objects:
            xmin, ymin, xmax, ymax = self._box_nodes(obj)
            xmin.data = self._jitter_min(int(xmin.data))
            ymin.data = self._jitter_min(int(ymin.data))
            xmax.data = self._jitter_max(int(xmax.data), width)
            ymax.data = self._jitter_max(int(ymax.data), height)

    def _write_variants(self, str_names, edit_root):
        """Save one edited copy of every XML file per prefix in *str_names*.

        *edit_root* is called with the document element of a freshly parsed
        copy, after its ``<filename>`` has been given the prefix.
        """
        for xmlfile in os.listdir(self.xmlpath):
            # Skip stray non-annotation files instead of crashing the parser.
            if not xmlfile.lower().endswith(".xml"):
                continue
            source = os.path.join(self.xmlpath, xmlfile)
            for str_name in str_names:
                # Parse again for each prefix so that the edits of one
                # variant never leak into the next.
                dom = xml.dom.minidom.parse(source)
                root = dom.documentElement
                filename = root.getElementsByTagName("filename")[0].firstChild
                filename.data = str_name + filename.data
                edit_root(root)
                savename = os.path.join(
                    self.xml_save_path, str_name + xmlfile)
                # Write UTF-8 and declare it, so non-ASCII text stays
                # readable whatever the platform's default encoding is.
                with open(savename, "w", encoding="utf-8") as f:
                    dom.writexml(f, encoding="utf-8")
                print(savename, "is save")

    def _size_nodes(self, root):
        """Return the text nodes of the first ``<width>`` and ``<height>``."""
        return (root.getElementsByTagName("width")[0].firstChild,
                root.getElementsByTagName("height")[0].firstChild)

    def _image_size(self, root):
        """Return ``(width, height)`` of the annotated image as integers."""
        width_node, height_node = self._size_nodes(root)
        return int(width_node.data), int(height_node.data)

    def _box_nodes(self, obj):
        """Return the text nodes of xmin, ymin, xmax and ymax of *obj*."""
        return tuple(obj.getElementsByTagName(tag)[0].firstChild
                     for tag in self._BOX_TAGS)

    def _jitter_min(self, value):
        """Lower *value* by 0..RANDOM_SIZE, not below 0; returned as text."""
        return str(max(value - random.randint(0, self.RANDOM_SIZE), 0))

    def _jitter_max(self, value, limit):
        """Raise *value* by 0..RANDOM_SIZE, not above *limit*; as text."""
        return str(min(value + random.randint(0, self.RANDOM_SIZE), limit))
def _xml_reject_reason(xml_file, image_names):
    """Return why the annotation *xml_file* is unusable, or None if it is fine.

    The returned text is the suffix printed by ``inspect_xml``.
    """
    root = xml.dom.minidom.parse(xml_file).documentElement
    filename_value = root.getElementsByTagName("filename")[0].firstChild.data
    if filename_value not in image_names:
        return " by not in "
    objects = root.getElementsByTagName("object")
    if len(objects) == 0:
        return " by not objects "
    depth_value = int(root.getElementsByTagName("depth")[0].firstChild.data)
    width_value = int(root.getElementsByTagName("width")[0].firstChild.data)
    height_value = int(root.getElementsByTagName("height")[0].firstChild.data)
    if depth_value != 3 or width_value == 0 or height_value == 0:
        return " by depth "
    for obj in objects:
        xmin, ymin, xmax, ymax = (
            int(obj.getElementsByTagName(tag)[0].firstChild.data)
            for tag in ("xmin", "ymin", "xmax", "ymax"))
        # A box without positive width and height is useless as a label.
        if xmin >= xmax or ymin >= ymax:
            return " by boxes "
    return None


def inspect_xml(imgpath, xmlpath):
    """Delete every annotation in *xmlpath* that cannot be used for training.

    An XML file is removed when its ``<filename>`` is not the name of a file
    in *imgpath*, when it has no ``<object>``, when the image depth is not 3
    or a size is 0, or when any bounding box has no positive area.  Files
    that do not end in ``.xml`` are left untouched.

    Args:
        imgpath: Directory holding the images.
        xmlpath: Directory holding the labelImg XML files; files are deleted
            from it in place.
    """
    # A set makes the per-annotation lookup O(1) instead of scanning a list.
    all_img = set(os.listdir(imgpath))
    for xmlfile in os.listdir(xmlpath):
        # Anything else would only make the XML parser raise.
        if not xmlfile.lower().endswith(".xml"):
            continue
        xml_file = os.path.join(xmlpath, xmlfile)
        reason = _xml_reject_reason(xml_file, all_img)
        if reason is not None:
            os.remove(xml_file)
            print(xml_file, " is remove", reason)
def main(imgpath, img_save_path, xmlpath, xml_save_path, auglist):
    """Generate augmented images and annotations, then validate both sets.

    Args:
        imgpath: Directory with the original ``.jpg`` images.
        img_save_path: Directory receiving the augmented images.
        xmlpath: Directory with the original labelImg XML files.
        xml_save_path: Directory receiving the augmented XML files.
        auglist: Indices of the augmentation families to run.
    """
    ImgAugmentation(imgpath, img_save_path).gen(auglist)
    XmlAugmentation(xmlpath, xml_save_path).gen(auglist)
    # Check the original pair first, then the generated pair; inspect_xml
    # deletes annotation files it considers unusable.
    for image_dir, annotation_dir in ((imgpath, xmlpath),
                                      (img_save_path, xml_save_path)):
        inspect_xml(image_dir, annotation_dir)
if __name__ == '__main__':
    # Input and output directories; adjust to the local data set.
    imgpath = r"C:\self\others\data-augmentation-master\x-image"
    img_save_path = r"C:\self\others\data-augmentation-master\x-gen"
    xmlpath = r"C:\self\others\data-augmentation-master\xmlfile"
    xml_save_path = r"C:\self\others\data-augmentation-master\xmlfilegen"
    # Augmentation families to run; remove an index to skip that family.
    #   1: adjust brightness
    #   2: add noise
    #   3: blur
    #   4: combinations of 1, 2 and 3 (12, 13, 23, 123)
    #   5: horizontal flip, then brightness / noise / blur on the result
    #   6: vertical flip, then brightness / noise / blur on the result
    #   7: rescale by 0.8, then brightness / noise / blur on the result
    #   8: crop 10% off the top and left, then brightness / noise / blur
    auglist = [1, 2, 3, 4, 5, 6, 7, 8]
    main(imgpath, img_save_path, xmlpath, xml_save_path, auglist)