在只有一张图片,或者图片样本需要扩展时。可以使用如下代码实现图片的扩展.因为处理顺序不同,得到的图片效果会不同。所以一共有16种排序。就不一一列出。主要代码如下:
#!/usr/bin/env python # -*- coding:utf-8 -*- import tensorflow as tf import numpy as np import matplotlib.pyplot as plt #不同处理顺序得到的图片处理结果不同 def distort_color(image, color_ordering=0): if color_ordering == 0: image = tf.image.random_flip_up_down(image) #50%概率上下翻转 image = tf.image.random_flip_left_right(image) #50%概率左右翻转 image = tf.image.random_brightness(image, max_delta=32./255.) #随机亮度 image = tf.image.random_saturation(image, lower=0.5, upper=1.5) #随机饱和度 image = tf.image.random_hue(image, max_delta=0.2) #随机色相 image = tf.image.random_contrast(image, lower=0.5, upper=1.5) #随机对比度 elif color_ordering == 1: image = tf.image.random_flip_up_down(image) # 50%概率上下翻转 image = tf.image.random_brightness(image, max_delta=32. / 255.) # 随机亮度 image = tf.image.random_flip_left_right(image) # 50%概率左右翻转 image = tf.image.random_saturation(image, lower=0.5, upper=1.5) # 随机饱和度 image = tf.image.random_hue(image, max_delta=0.2) # 随机色相 image = tf.image.random_contrast(image, lower=0.5, upper=1.5) # 随机对比度 elif color_ordering == 2: image = tf.image.random_flip_up_down(image) # 50%概率上下翻转 image = tf.image.random_brightness(image, max_delta=32. / 255.) # 随机亮度 image = tf.image.random_saturation(image, lower=0.5, upper=1.5) # 随机饱和度 image = tf.image.random_flip_left_right(image) # 50%概率左右翻转 image = tf.image.random_hue(image, max_delta=0.2) # 随机色相 image = tf.image.random_contrast(image, lower=0.5, upper=1.5) # 随机对比度 return tf.clip_by_value(image, 0.0, 1.0) def preprocess_for_train(image, height, width, bbox): if bbox is None: bbox = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4]) if image.dtype != tf.float32: image = tf.image.convert_image_dtype(image, dtype=tf.float32) #随机截取图像,减小需要关注的大小对图像识别的影响 #bbox_begin, bbox_size, _ = tf.image.sample_distorted_bounding_box(tf.shape(image), bounding_boxes=bbox) #distorted_image = tf.slice(image, bbox_begin, bbox_size) #将截取结果调整为神经网络的输入层大小,大小调整算法随机 image = tf.image.resize_images( image, [height, width], method=np.random.randint(4) ) #选择一种顺序 distorted_image = distort_color(image, np.random.randint(3)) return distorted_image image_raw_data = tf.gfile.FastGFile(r"C:\Users\SUM\Desktop\ceshipiture\_001a_0.png", 'rb').read() with tf.Session() as sess: img_data = tf.image.decode_png(image_raw_data) boxes = tf.constant([[[0.0, 0.0, 1.0, 1.0]]]) for i in range(10): result = preprocess_for_train(img_data, 100, 200, boxes) plt.imshow(result.eval()) plt.show()
注意的是,在指定路径时需要将其字符串化,否则会报如下错误:
解决的办法就是在路径前加‘r’以字符串化。