Python skimage 数据增强

记录一下自己使用 skimage 做数据增强时所遇到的一些情况。

实现的功能:对图片进行调明暗、模糊、加噪音、裁剪、缩放、翻转及其组合,并生成对应的 labelImg 格式标注 XML 文件。

遇到的问题如下。

问题1:生成高斯噪音时,图像上有很明显的杂色聚集。

原因:给像素值加高斯噪音 img + random.gauss(mean, sigma) 时,部分像素会小于 0 或大于 255;转回 uint8 时这些越界值会溢出回绕(例如 -3 变成 253),于是出现杂色。

解决办法:先把图像转成 int16 再加噪音,然后做截断 img[img>255]=255,img[img<0]=0,最后再转回 uint8。

问题2:有部分图像保存失败。

原因:在调用函数 filters.gaussian() 时,数据被转成了 float,而 skimage 要求 float 图像的取值范围在 [-1, 1] 之间,部分图片越界导致保存失败。

解决办法:数据截断 img[img>1]=1,img[img<-1] = -1

问题3:保存的图片变成纯白或纯黑。

原因:在组合多种增强方式时,各函数返回的 dtype 不一致(有的是 uint8,有的是 float64),按错误的取值范围做了截断。

解决办法:打印出每个返回图像数据的dtype,对应做截断。

"""
数据增强实现功能:定义图片生成类 ImgAugmentation 实现图片数据增强,并对应定义 XML 生成类 XmlAugmentation,
生成与之匹配的 labelImg 标注 XML 文件。使用时通过 auglist 参数选择要进行的增强内容,并设置原图路径、增强图片保存路径、原 XML 路径和增强 XML 保存路径。

"""
import skimage
from skimage import io, transform, filters
from skimage import color, data_dir, data
from skimage import exposure, img_as_float,img_as_ubyte,img_as_int,dtype_limits
import random
import cv2
import numpy as np
import os
import xml.dom.minidom

class ImgAugmentation:
    """Generate augmented copies of every ``*.jpg`` image in ``imgpath``.

    Each augmentation index (see ``gen``) maps to a load function that reads
    one source image and returns ``(source_path, img_a, img_b, ...)``. Each
    returned image is saved to ``img_save_path`` as ``prefix + source_name``.
    The prefixes and geometry constants below must stay in sync with
    ``XmlAugmentation`` so every generated image gets a matching annotation.

    For uint8 input every returned image is uint8. Images from different
    augmentation chains therefore combine and save without dtype-dependent
    clipping.
    """

    # Augmentation index -> output filename prefixes, in the order the images
    # are returned (after the source path) by the matching load function.
    AUG_PREFIXES = {
        1: ["bnh_", "bnl_"],                                              # brightness
        2: ["ns_", "gns_"],                                               # noise
        3: ["fil1_", "fil2_"],                                            # blur
        4: ["add12_", "add13_", "add23_", "add123_"],                     # combinations
        5: ["fliph_", "fliph12_", "fliph13_", "fliph23_", "fliph123_"],   # horizontal flip
        6: ["flipv_", "flipv12_", "flipv13_", "flipv23_", "flipv123_"],   # vertical flip
        7: ["rsl_", "rsl12_", "rsl13_", "rsl23_", "rsl123_"],             # rescale
        8: ["crop_", "crop12_", "crop13_", "crop23_", "crop123_"],        # crop
    }
    # Scale factor used by rescale_func (XmlAugmentation.xml_rescale also uses 0.8).
    RESCALE_FACTOR = 0.8
    # crop_func removes height // CROP_DIVISOR rows and width // CROP_DIVISOR
    # columns from the top-left corner (XmlAugmentation.xml_crop also uses // 10).
    CROP_DIVISOR = 10

    def __init__(self, imgpath, img_save_path):
        self.imgpath = imgpath              # directory scanned for *.jpg files
        self.img_save_path = img_save_path  # directory the augmented images are written to

    def gen(self, auglist):
        """Run every augmentation whose index appears in *auglist*.

        Indices: 1 brightness, 2 noise, 3 blur, 4 combinations of 1-3,
        5 horizontal flip, 6 vertical flip, 7 rescale by 0.8, 8 crop of the
        top-left 10%. Indices 5-8 are also combined with 1-3.
        Unknown indices are ignored.
        """
        for idx in sorted(self.AUG_PREFIXES):
            if idx in auglist:
                self.img_gen(idx, self.AUG_PREFIXES[idx])

    def img_gen(self, load_func_idx, str_lists):
        """Apply augmentation *load_func_idx* to every source image and save the results.

        Args:
            load_func_idx: augmentation index 1-8 (see ``gen``).
            str_lists: filename prefixes, one per augmented image returned by
                the load function after the source path.

        Raises:
            ValueError: if *load_func_idx* is not a known augmentation index.
        """
        load_funcs = {
            1: self.brightness_func,
            2: self.noise_func,
            3: self.filter_func,
            4: self.add_func,
            5: self.fliph_func,
            6: self.flipv_func,
            7: self.rescale_func,
            8: self.crop_func,
        }
        if load_func_idx not in load_funcs:
            raise ValueError("unknown augmentation index: %r" % (load_func_idx,))
        # os.path.join keeps the glob pattern portable across operating systems.
        pattern = os.path.join(self.imgpath, "*.jpg")
        coll = io.ImageCollection(pattern, load_func=load_funcs[load_func_idx])
        for i in range(len(coll)):
            try:
                result = coll[i]
            except Exception as e:  # one unreadable image must not abort the batch
                print(coll.files[i] + " is error to load")
                print(e)
                continue
            imgpath, images = result[0], result[1:]
            imgname = os.path.basename(imgpath)
            for prefix, img in zip(str_lists, images):
                save_path = os.path.join(self.img_save_path, prefix + imgname)
                try:
                    io.imsave(save_path, img)
                    print(save_path, "is save")
                except Exception as e:
                    print(save_path + " is error to save")
                    print(e)

    def brightness_func(self, imgfile):
        """Return ``(imgfile, brighter, darker)`` using gamma 0.5 and 1.5."""
        img = io.imread(imgfile)
        img1 = exposure.adjust_gamma(img, 0.5)  # gamma < 1 brightens
        img2 = exposure.adjust_gamma(img, 1.5)  # gamma > 1 darkens
        return imgfile, img1, img2

    def noise_func(self, imgfile):
        """Return ``(imgfile, salt_and_pepper_img, gaussian_noise_img)``."""
        img = io.imread(imgfile)
        return imgfile, self._salt_pepper(img), self._gauss_noise(img, 10)

    def filter_func(self, imgfile):
        """Return ``(imgfile, blur_sigma_1, blur_sigma_1_5)`` as uint8 images."""
        img = io.imread(imgfile)
        return imgfile, self._blur(img, 1), self._blur(img, 1.5)

    def add_func(self, imgfile):
        """Return ``(imgfile, img12, img13, img23, img123)``; see ``img_add_b_n_f``."""
        img = io.imread(imgfile)
        return (imgfile,) + tuple(self.img_add_b_n_f(img))

    def fliph_func(self, imgfile):
        """Return the horizontally flipped image followed by its 4 random combinations."""
        img_fliph = io.imread(imgfile)[:, ::-1]
        return (imgfile, img_fliph) + tuple(self.img_add_b_n_f(img_fliph))

    def flipv_func(self, imgfile):
        """Return the vertically flipped image followed by its 4 random combinations."""
        img_flipv = io.imread(imgfile)[::-1]
        return (imgfile, img_flipv) + tuple(self.img_add_b_n_f(img_flipv))

    def rescale_func(self, imgpath):
        """Return the image scaled by RESCALE_FACTOR followed by its 4 random combinations."""
        img = io.imread(imgpath)
        height, width = img.shape[:2]
        # Same rounding as transform.rescale's output shape (and as the XML side).
        out_shape = (max(1, int(round(height * self.RESCALE_FACTOR))),
                     max(1, int(round(width * self.RESCALE_FACTOR))))
        # resize() with a 2-tuple output shape leaves the channel axis untouched.
        # rescale() with a scalar factor scales *every* axis unless told the last
        # one holds channels, which turns 3 RGB channels into round(3 * 0.8) = 2.
        img_rsl = img_as_ubyte(transform.resize(img, out_shape))
        return (imgpath, img_rsl) + tuple(self.img_add_b_n_f(img_rsl))

    def crop_func(self, imgpath):
        """Return the image without its top-left margin followed by its 4 random combinations.

        ``height // CROP_DIVISOR`` rows are removed from the top and
        ``width // CROP_DIVISOR`` columns from the left.
        """
        img = io.imread(imgpath)
        height, width = img.shape[:2]
        img_crop = img[height // self.CROP_DIVISOR:, width // self.CROP_DIVISOR:]
        return (imgpath, img_crop) + tuple(self.img_add_b_n_f(img_crop))

    def img_add_b_n_f(self, img):
        """Return random combinations of brightness (1), noise (2) and blur (3).

        Returns:
            ``(img12, img13, img23, img123)``: uint8 images (for uint8 input) with
            a random gamma in [0.5, 1.5], a random blur sigma in [1, 1.5], and
            either Gaussian or salt-and-pepper noise chosen at random.
        """
        # 1 + 2: random brightness, then random noise (Gaussian std 5-15).
        img12 = self._random_noise(
            exposure.adjust_gamma(img, random.uniform(0.5, 1.5)), random.randint(5, 15))
        # 1 + 3: random brightness, then random blur.
        img13 = self._blur(
            exposure.adjust_gamma(img, random.uniform(0.5, 1.5)), random.uniform(1, 1.5))
        # 2 + 3: random blur, then random noise (Gaussian std 10).
        img23 = self._random_noise(self._blur(img, random.uniform(1, 1.5)), 10)
        # 1 + 2 + 3: brightness, blur, then noise.
        img123 = self._random_noise(
            self._blur(exposure.adjust_gamma(img, random.uniform(0.5, 1.5)),
                       random.uniform(1, 1.5)),
            10)
        return img12, img13, img23, img123

    @staticmethod
    def _salt_pepper(img, ratio=0.005):
        """Return a copy of uint8 *img* with ~``ratio * H * W`` random pixels set to 255 or 0.

        Works for grayscale (H, W) and colour (H, W, C) images; a hit pixel has
        all of its channels set.
        """
        noisy = img.copy()
        height, width = noisy.shape[:2]
        num = int(width * height * ratio)
        if num == 0:
            return noisy
        # randint's upper bound is exclusive, so this reaches every row and column
        # (randint(0, width - 1) would never touch the last ones).
        ys = np.random.randint(0, height, num)
        xs = np.random.randint(0, width, num)
        salt = np.random.randint(0, 2, num).astype(bool)
        noisy[ys[salt], xs[salt]] = 255
        noisy[ys[~salt], xs[~salt]] = 0
        return noisy

    @staticmethod
    def _gauss_noise(img, sigma):
        """Add N(0, sigma) noise to uint8 *img*, clipping to [0, 255] before casting back."""
        # Compute outside uint8 so values below 0 / above 255 are clipped, not wrapped.
        noisy = img.astype(np.int16) + np.random.normal(0, sigma, img.shape)
        return np.clip(noisy, 0, 255).astype(np.uint8)

    def _random_noise(self, img, sigma):
        """With equal probability add Gaussian noise (std *sigma*) or salt-and-pepper noise."""
        if random.randint(0, 1) == 0:
            return self._gauss_noise(img, sigma)
        return self._salt_pepper(img)

    @staticmethod
    def _blur(img, sigma):
        """Gaussian-blur *img* spatially and return it as uint8.

        ``filters.gaussian`` returns floats in [0, 1] for integer input. They are
        clipped to [0, 1] and converted back to uint8, so later noise steps and
        ``io.imsave`` always receive the same dtype.
        """
        if img.ndim == 3:
            # Blur only the spatial axes, never across the colour channels.
            try:
                blurred = filters.gaussian(img, sigma, channel_axis=-1)  # skimage >= 0.19
            except TypeError:
                blurred = filters.gaussian(img, sigma, multichannel=True)  # older skimage
        else:
            blurred = filters.gaussian(img, sigma)
        return img_as_ubyte(np.clip(blurred, 0, 1))

class XmlAugmentation:
    """Generate labelImg (Pascal VOC) annotation files matching ``ImgAugmentation`` output.

    For every ``*.xml`` in ``xmlpath`` and every filename prefix of the
    selected augmentations, a copy is written to ``xml_save_path`` as
    ``prefix + xmlfile``. In the copy, ``<filename>`` is prefixed the same way,
    and sizes and boxes are adjusted to the image transform. Boxes are also
    widened by a random 0..RANDOM_SIZE pixels per side, clamped to the image.
    """

    # Augmentation index -> filename prefixes (same strings as ImgAugmentation.gen).
    AUG_PREFIXES = {
        1: ["bnh_", "bnl_"],
        2: ["ns_", "gns_"],
        3: ["fil1_", "fil2_"],
        4: ["add12_", "add13_", "add23_", "add123_"],
        5: ["fliph_", "fliph12_", "fliph13_", "fliph23_", "fliph123_"],
        6: ["flipv_", "flipv12_", "flipv13_", "flipv23_", "flipv123_"],
        7: ["rsl_", "rsl12_", "rsl13_", "rsl23_", "rsl123_"],
        8: ["crop_", "crop12_", "crop13_", "crop23_", "crop123_"],
    }
    RANDOM_SIZE = 3       # maximum random outward jitter (pixels) per box side
    RESCALE_FACTOR = 0.8  # must match ImgAugmentation.rescale_func
    CROP_DIVISOR = 10     # must match ImgAugmentation.crop_func (crops size // 10)

    def __init__(self, xmlpath, xml_save_path):
        self.xmlpath = xmlpath              # directory holding the source *.xml files
        self.xml_save_path = xml_save_path  # directory the generated *.xml files go to

    def gen(self, auglist):
        """Write annotations for every augmentation index present in *auglist*.

        Indices 1-4 keep the geometry; 5 and 6 flip; 7 rescales; 8 crops.
        Unknown indices are ignored.
        """
        handlers = {
            1: self.xml_same_size,
            2: self.xml_same_size,
            3: self.xml_same_size,
            4: self.xml_same_size,
            5: self.xml_fliph,
            6: self.xml_flipv,
            7: self.xml_rescale,
            8: self.xml_crop,
        }
        for idx in sorted(self.AUG_PREFIXES):
            if idx in auglist:
                handlers[idx](self.AUG_PREFIXES[idx])

    @staticmethod
    def _get_int(node, tag):
        """Return the integer text of the first ``<tag>`` below *node* (accepts "12" or "12.0")."""
        return int(float(node.getElementsByTagName(tag)[0].firstChild.data))

    @staticmethod
    def _set_int(node, tag, value):
        """Store *value* as the text of the first ``<tag>`` below *node*."""
        node.getElementsByTagName(tag)[0].firstChild.data = str(value)

    def _augment(self, str_names, adjust):
        """Write one adjusted copy of every source annotation per prefix in *str_names*.

        Each copy is parsed afresh from the source file, so adjustments never
        accumulate. ``adjust(root)`` mutates sizes and boxes in place.
        Files not ending in ``.xml`` are skipped.
        """
        for xmlfile in sorted(os.listdir(self.xmlpath)):
            if not xmlfile.lower().endswith(".xml"):
                continue
            src = os.path.join(self.xmlpath, xmlfile)
            for str_name in str_names:
                dom = xml.dom.minidom.parse(src)
                root = dom.documentElement
                filename = root.getElementsByTagName("filename")[0].firstChild
                filename.data = str_name + filename.data
                adjust(root)
                savename = os.path.join(self.xml_save_path, str_name + xmlfile)
                # Write UTF-8 explicitly and declare it. The platform default
                # encoding (e.g. GBK on Chinese Windows) would corrupt or reject
                # non-ASCII label names.
                with open(savename, "w", encoding="utf-8") as f:
                    dom.writexml(f, encoding="utf-8")
                print(savename, "is save")

    def _write_jittered_box(self, obj, xmin, ymin, xmax, ymax, width, height):
        """Store the box widened by 0..RANDOM_SIZE px per side, clamped to [0, width] x [0, height]."""
        r = self.RANDOM_SIZE
        self._set_int(obj, "xmin", max(xmin - random.randint(0, r), 0))
        self._set_int(obj, "ymin", max(ymin - random.randint(0, r), 0))
        self._set_int(obj, "xmax", min(xmax + random.randint(0, r), width))
        self._set_int(obj, "ymax", min(ymax + random.randint(0, r), height))

    def _read_box(self, obj):
        """Return ``(xmin, ymin, xmax, ymax)`` of an ``<object>`` node."""
        return tuple(self._get_int(obj, t) for t in ("xmin", "ymin", "xmax", "ymax"))

    def xml_same_size(self, str_names):
        """Annotations for augmentations that keep the image geometry (brightness/noise/blur)."""
        def adjust(root):
            width = self._get_int(root, "width")
            height = self._get_int(root, "height")
            self.random_bndbox(root.getElementsByTagName("object"), width, height)
        self._augment(str_names, adjust)

    def xml_fliph(self, str_names):
        """Annotations for horizontally flipped images."""
        def adjust(root):
            width = self._get_int(root, "width")
            height = self._get_int(root, "height")
            self.fliph_random_bndbox(root.getElementsByTagName("object"), width, height)
        self._augment(str_names, adjust)

    def fliph_random_bndbox(self, objects, width_value, height_value):
        """Mirror boxes across the vertical centre line (x -> width - x), then jitter them."""
        for obj in objects:
            xmin, ymin, xmax, ymax = self._read_box(obj)
            self._write_jittered_box(obj, width_value - xmax, ymin, width_value - xmin, ymax,
                                     width_value, height_value)

    def xml_flipv(self, str_names):
        """Annotations for vertically flipped images."""
        def adjust(root):
            width = self._get_int(root, "width")
            height = self._get_int(root, "height")
            self.flipv_random_bndbox(root.getElementsByTagName("object"), width, height)
        self._augment(str_names, adjust)

    def xml_rescale(self, str_names):
        """Annotations for images scaled by RESCALE_FACTOR; <width>/<height> are updated too."""
        def adjust(root):
            width = max(1, round(self._get_int(root, "width") * self.RESCALE_FACTOR))
            height = max(1, round(self._get_int(root, "height") * self.RESCALE_FACTOR))
            self._set_int(root, "width", width)
            self._set_int(root, "height", height)
            self.rescale_random_bndbox(root.getElementsByTagName("object"), width, height)
        self._augment(str_names, adjust)

    def xml_crop(self, str_names):
        """Annotations for images with the top-left margin (size // CROP_DIVISOR) removed."""
        def adjust(root):
            wid = self._get_int(root, "width")
            hei = self._get_int(root, "height")
            crop_w = wid // self.CROP_DIVISOR
            crop_h = hei // self.CROP_DIVISOR
            self._set_int(root, "width", wid - crop_w)
            self._set_int(root, "height", hei - crop_h)
            self.crop_random_bndbox(root.getElementsByTagName("object"), crop_w, crop_h, root)
        self._augment(str_names, adjust)

    def crop_random_bndbox(self, objects, crop_w, crop_h, root):
        """Shift boxes left/up by the crop offsets; drop objects whose top-left corner was cut off.

        *root* is kept for interface compatibility. Dropped objects are
        detached from their actual parent node, which need not be *root*.
        """
        dropped = []
        for obj in objects:
            xmin, ymin, xmax, ymax = self._read_box(obj)
            if xmin - crop_w < 0 or ymin - crop_h < 0:
                dropped.append(obj)
                continue
            self._set_int(obj, "xmin", xmin - crop_w)
            self._set_int(obj, "ymin", ymin - crop_h)
            self._set_int(obj, "xmax", xmax - crop_w)
            self._set_int(obj, "ymax", ymax - crop_h)
        # Remove after iterating so the node list is not disturbed mid-loop.
        for obj in dropped:
            obj.parentNode.removeChild(obj)

    def rescale_random_bndbox(self, objects, width_value, height_value):
        """Scale every coordinate by RESCALE_FACTOR (truncating), then jitter the box."""
        f = self.RESCALE_FACTOR
        for obj in objects:
            xmin, ymin, xmax, ymax = self._read_box(obj)
            self._write_jittered_box(obj, int(xmin * f), int(ymin * f), int(xmax * f),
                                     int(ymax * f), width_value, height_value)

    def flipv_random_bndbox(self, objects, width_value, height_value):
        """Mirror boxes across the horizontal centre line (y -> height - y), then jitter them."""
        for obj in objects:
            xmin, ymin, xmax, ymax = self._read_box(obj)
            self._write_jittered_box(obj, xmin, height_value - ymax, xmax, height_value - ymin,
                                     width_value, height_value)

    def random_bndbox(self, xml_objects, width, height):
        """Jitter boxes outward by 0..RANDOM_SIZE px per side without moving them."""
        for obj in xml_objects:
            xmin, ymin, xmax, ymax = self._read_box(obj)
            self._write_jittered_box(obj, xmin, ymin, xmax, ymax, width, height)

def inspect_xml(imgpath, xmlpath):
    """Delete annotation files in *xmlpath* that are unusable for training.

    An ``.xml`` file is removed (and the reason printed) when:
      * its <filename> is not a file in *imgpath*;
      * it contains no <object>;
      * <depth> is not 3, or <width> / <height> is 0;
      * any box has xmin >= xmax or ymin >= ymax.
    Files not ending in ``.xml`` are left alone.
    """
    all_img = set(os.listdir(imgpath))  # set: O(1) membership test per annotation
    for xmlfile in os.listdir(xmlpath):
        if not xmlfile.lower().endswith(".xml"):
            continue
        path = os.path.join(xmlpath, xmlfile)
        reason = _xml_reject_reason(path, all_img)
        if reason is not None:
            os.remove(path)
            print(path, " is remove", reason)


def _xml_reject_reason(path, all_img):
    """Return the printed reason the annotation at *path* should be deleted, or None to keep it."""
    root = xml.dom.minidom.parse(path).documentElement

    def first_int(node, tag):
        return int(node.getElementsByTagName(tag)[0].firstChild.data)

    if root.getElementsByTagName("filename")[0].firstChild.data not in all_img:
        return " by not in "
    objects = root.getElementsByTagName("object")
    if not objects:
        return " by not objects "
    depth = first_int(root, "depth")
    width = first_int(root, "width")
    height = first_int(root, "height")
    if depth != 3 or width == 0 or height == 0:
        return " by depth "
    for obj in objects:
        xmin, ymin = first_int(obj, "xmin"), first_int(obj, "ymin")
        xmax, ymax = first_int(obj, "xmax"), first_int(obj, "ymax")
        if xmin >= xmax or ymin >= ymax:
            return " by boxes "
    return None


def main(imgpath, img_save_path, xmlpath, xml_save_path, auglist):
    """Augment images and their annotations, then prune invalid annotation files.

    The image augmentations run first, then the matching XML augmentations.
    Finally ``inspect_xml`` runs on the source (images, annotations) pair and
    on the output pair. Note that this also deletes invalid files from the
    *source* annotation directory.
    """
    ImgAugmentation(imgpath, img_save_path).gen(auglist)
    XmlAugmentation(xmlpath, xml_save_path).gen(auglist)
    for images_dir, annotations_dir in ((imgpath, xmlpath),
                                        (img_save_path, xml_save_path)):
        inspect_xml(images_dir, annotations_dir)

if __name__ == '__main__':
    # auglist selects which augmentations to run:
    #   1: adjust brightness
    #   2: add noise
    #   3: blur
    #   4: combinations of 1, 2 and 3 (12, 13, 23, 123)
    #   5: horizontal flip, then brightness / noise / blur
    #   6: vertical flip, then brightness / noise / blur
    #   7: rescale by 0.8, then brightness / noise / blur
    #   8: crop away the top 10% of rows and left 10% of columns,
    #      then brightness / noise / blur
    auglist = list(range(1, 9))

    # Source images / annotations and where to write the generated ones.
    imgpath = r"C:\self\others\data-augmentation-master\x-image"
    xmlpath = r"C:\self\others\data-augmentation-master\xmlfile"
    img_save_path = r"C:\self\others\data-augmentation-master\x-gen"
    xml_save_path = r"C:\self\others\data-augmentation-master\xmlfilegen"

    main(imgpath, img_save_path, xmlpath, xml_save_path, auglist)

 

 

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值