在语义分割中数据增强的正确打开方式(tf.data)

通过一系列的
随即翻转
随机旋转
随机剪裁
随机亮度调节
随机对比度调节
随机色相调节
随机饱和度调节
随机高斯噪声
让数据集变得更强大!

class Parse(object):
	'''
	 callable的类,返回一个进行数据解析和数据增强的仿函数
		使用示例:
			def make_dataset(self,batch_size,output_shape,argumentation=False,buffer_size=4000,epochs=None,shuffle=True):
		        filename=[self.tfr_path]
			    parse=Parse(self.img_shape,output_shape,argumentation,num_classes=self.num_classes)
			    dataset=tf.data.TFRecordDataset(filename)
			    dataset=dataset.prefetch(tf.contrib.data.AUTOTUNE)
			    dataset=dataset.shuffle(buffer_size=buffer_size,seed=int(time()))
		        dataset=dataset.repeat(count=None
		        dataset=dataset.map(parse)
		        dataset=dataset.batch(batch_size=batch_size)
		        dataset=dataset.apply(tf.data.experimental.prefetch_to_device("/gpu:0"))
		        return dataset
    '''
    def __init__(self,raw_shape,out_shape,argumentation,num_classes):
        """
            返回一个Parse类的对象
            raw_shape:TFRecord文件中example的图像大小
            out_shape:随机剪裁后的图像大小
            argumentation:Bool变量,如果为0就只解析出图像并裁剪,不进行任何数据增强
            num_classes:类别总数(包括背景),用于确定one hot的维度
        """
        self.__argumantation=argumentation
        self.raw_shape=raw_shape
        self.out_shape=out_shape
        self.num_classes=num_classes
    def argumentation(self,image,labels):
        """
           单独对标签进行数据增强
           输入标签的one hot编码后的张量,输出和原图大小同样的张量
        """
        image=tf.cast(image,tf.float32)
        image,labels=self.random_crop_flip_rotate(image,labels)
        image=tf.image.random_brightness(image,max_delta=0.4)  #随机亮度调节
        image=tf.image.random_contrast(image,lower=0.7,upper=1.3)#随机对比度
        image=tf.image.random_hue(image,max_delta=0.3)#随机色相
        image=tf.image.random_saturation(image,lower=0.8,upper=1.3)#随机饱和度
        image=tf.cast(image,tf.float32)
        image=image+tf.truncated_normal(stddev=4,mean=2,shape=image.shape.as_list(),seed=int(time()))#加上高斯噪声
        image=tf.clip_by_value(image,0.0,255.0)
        return image,labels
    
    def random_rotate(self,input_image, min_angle=-np.pi/2,max_angle=np.pi/2):
        '''
        TensorFlow对图像进行随机旋转
        :param input_image: 图像输入
        :param min_angle: 最小旋转角度
        :param max_angle: 最大旋转角度
        :return: 旋转后的图像
        '''
        distorted_image = tf.expand_dims(input_image, 0)
        random_angles = tf.random.uniform(shape=(tf.shape(distorted_image)[0],), minval = min_angle , maxval = max_angle)
        distorted_image = tf.contrib.image.transform(
            distorted_image,
            tf.contrib.image.angles_to_projective_transforms(
                random_angles, tf.cast(tf.shape(distorted_image)[1], tf.float32), tf.cast(tf.shape(distorted_image)[2], tf.float32)
            ))
        rotate_image = tf.squeeze(distorted_image, [0])
        return rotate_image
    def random_crop_flip_rotate(self,image1,image2):#图片和对应标签同步进行翻转,旋转和裁剪
        image=tf.concat([image1,image2],axis=-1)
        channel=image.shape.as_list()[-1]
        shape=self.out_shape+[channel]
        print(shape)
        image=self.random_rotate(image)
        image=tf.image.random_crop(image,shape)
        image=tf.image.random_flip_left_right(image)
        image1=tf.slice(image,[0,0,0],self.out_shape+[3])
        image2=tf.slice(image,[0,0,3],self.out_shape+[-1])
        return image1,image2
    def __call__(self,tensor):
        """
            make it callable
            inputs labels 分别为图片和对应标签
            应当具有[w,h,3] 和[w,h] 的形状
        """
        feature=tf.parse_single_example(tensor,features={
            "inputs":tf.FixedLenFeature([],tf.string),
            "labels":tf.FixedLenFeature([],tf.string)
        })
        inputs=tf.decode_raw(feature["inputs"],tf.uint8)
        inputs=tf.reshape(inputs,self.raw_shape+[3])
        labels=tf.decode_raw(feature["labels"],tf.uint8)
        labels=tf.reshape(labels,self.raw_shape)
        labels=tf.one_hot(labels,self.num_classes)
        if self.__argumantation:
            inputs,labels=self.argumentation(inputs,labels)
        else:
            inputs=tf.image.resize_image_with_crop_or_pad(inputs,self.out_shape[0],self.out_shape[1])
            labels=tf.image.resize_image_with_crop_or_pad(labels,self.out_shape[0],self.out_shape[1])
        inputs=tf.image.per_image_standardization(inputs) #标准化
        return inputs,labels

#UPD:
在这里插入图片描述
初步排查了一下是random_size里面tf.image.resize_images的问题
解决起来好办
就是resize过后强行把边框几个像素丢掉

    def random_size(self,image,minratio=0.5,maxratio=2.0,prob=0.5):
        height,width=image.shape.as_list()[:2]
        min_height=height*minratio
        max_height=height*maxratio
        min_width=width*minratio
        max_width=width*maxratio
        height=tf.random_uniform(shape=[],minval=min_height,maxval=max_height)
        height=tf.cast(height,tf.int32)
        width=tf.random_uniform(shape=[],minval=min_width,maxval=max_width)
        width=tf.cast(width,tf.int32)
        _prob=tf.random_uniform(shape=[],minval=0.0,maxval=1.0)
        r_image=tf.image.resize_images(image,[height+4,width+4],method=2)
        r_image=tf.image.resize_image_with_crop_or_pad(r_image,self.out_shape[0],self.out_shape[1])
        return tf.cond(_prob<prob,lambda:r_image,lambda:image)

虽然损失了边框部分
但是问题不是很大,反而不去除的话会干扰训练过程
网络表示并不能知道为什么图片中有个边界框…

最终效果:
在这里插入图片描述

  • 3
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值