How to prepare training data with TensorFlow?

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

# The following functions can be used to convert a value to a type compatible with tf.Example.
def _bytes_feature(value):
    """Returns a bytes_list from a string / byte."""
    if isinstance(value, type(tf.constant(0))):
        value = value.numpy() # BytesList won't unpack a string from an EagerTensor.
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _float_feature(value):
    """Returns a float_list from a float(float32) / double(float64)."""
    return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))

def _int64_feature(value):
    """Returns an int64_list from a bool / enum / int32 / uint32 / int64 / uint64."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
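
# A quick sanity check of the helpers above (a sketch; the values are arbitrary examples,
# and each call prints the corresponding tf.train.Feature proto):
'''
print(_int64_feature(1))             # prints an int64_list Feature
print(_float_feature(0.5))           # prints a float_list Feature
print(_bytes_feature(b'raw bytes'))  # prints a bytes_list Feature
'''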

# Create a dictionary with features that may be relevant.
# In newer TensorFlow versions (eager execution) the shape can be read directly from the decoded tensor.
def image_example(img_raw, label):
    img_tensor = tf.image.decode_jpeg(img_raw)
    image_shape = img_tensor.shape
    feature = {
        'height': _int64_feature(image_shape[0]),
        'width': _int64_feature(image_shape[1]),
        'depth': _int64_feature(image_shape[2]),
        'label': _int64_feature(label),
        'image_raw': _bytes_feature(img_raw),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature))
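
# A minimal usage sketch for image_example(); it assumes eager execution (TF 2.x), where
# decode_jpeg() returns a tensor with a concrete shape, and 'test1.jpg' is just a placeholder name:
'''
img_raw = open('test1.jpg', 'rb').read()
print(image_example(img_raw, 0))
'''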

# Create a dictionary with features that may be relevant.
# Older (graph-mode) TensorFlow versions cannot read the shape directly after decoding, so the tensor is
# first evaluated to a numpy.ndarray and the shape is taken from that.
def image_example_sess(img_raw, label, sess):
    img_tensor = tf.image.decode_jpeg(img_raw)
    with sess.as_default():
        img_data = img_tensor.eval()
        #print(type(img_data))
        image_shape = img_data.shape
    feature = {
        'height': _int64_feature(image_shape[0]),
        'width': _int64_feature(image_shape[1]),
        'depth': _int64_feature(image_shape[2]),
        'label': _int64_feature(label),
        'image_raw': _bytes_feature(img_raw),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature))

'''
img_raw = tf.gfile.FastGFile('test1.jpg', 'rb').read()
label = 0
with tf.Session() as sess:
    print(image_example_sess(img_raw, label, sess))
'''
#####################################################################################################
# Write the TFRecord file
# Write the raw image files to `images.tfrecords`.
# First, process the two images into `tf.Example` messages.
# Then, write to a `.tfrecords` file.
image_labels = {
    '1.jpg': 0, '2.jpg': 0, '3.jpg': 1, '4.jpg': 1, '5.jpg': 2,
    '6.jpg': 2, '7.jpg': 3, '8.jpg': 3, '9.jpg': 4, '10.jpg': 4,
}
record_file = 'images.tfrecords'
with tf.Session() as sess:
    with tf.io.TFRecordWriter(record_file) as writer:
        for filename, label in image_labels.items():
            img_raw = tf.gfile.FastGFile(filename, 'rb').read()
            tf_example = image_example_sess(img_raw, label, sess)
            writer.write(tf_example.SerializeToString())

def _parse_image_function(example_proto):
    # Create a dictionary describing the features.
    image_feature_description = {
        'height': tf.io.FixedLenFeature([], tf.int64),
        'width': tf.io.FixedLenFeature([], tf.int64),
        'depth': tf.io.FixedLenFeature([], tf.int64),
        'label': tf.io.FixedLenFeature([], tf.int64),
        'image_raw': tf.io.FixedLenFeature([], tf.string),
        }
    # Parse the input tf.Example proto using the dictionary above.
    return tf.io.parse_single_example(example_proto, image_feature_description)  # the return value is a <class 'dict'>
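
# The records written above can also be inspected without building a graph, by iterating the raw
# records and parsing each one as a tf.train.Example proto (a sketch using the TF 1.x
# tf.python_io.tf_record_iterator API):
'''
for record in tf.python_io.tf_record_iterator('images.tfrecords'):
    example = tf.train.Example.FromString(record)
    feat = example.features.feature
    print(feat['label'].int64_list.value[0],
          feat['height'].int64_list.value[0],
          feat['width'].int64_list.value[0])
'''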

#####################################################################################################
# Read the TFRecord file
input_files = ['images.tfrecords']  # multiple files are allowed
raw_image_dataset = tf.data.TFRecordDataset(input_files)
parsed_image_dataset = raw_image_dataset.map(_parse_image_function)
iterator = parsed_image_dataset.make_one_shot_iterator()
feature_dict = iterator.get_next()
with tf.Session() as sess:
    for i in range(len(image_labels)):
        feature_dict_val = sess.run(feature_dict)
        #print(feature_dict['height'])
        #print('height: ', feature_dict_val['height'])
        #print('width: ', feature_dict_val['width'])
        #print('depth: ', feature_dict_val['depth'])
        #print('label: ', feature_dict_val['label'])
        img = tf.io.decode_image(feature_dict_val['image_raw']).eval()
        #plt.imshow(img)
        #plt.show()
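
# Note that calling tf.io.decode_image() inside the sess.run() loop above adds new ops to the graph
# on every iteration. A cleaner variant (a sketch reusing _parse_image_function) decodes once,
# inside the dataset pipeline:
'''
def _parse_and_decode(example_proto):
    parsed = _parse_image_function(example_proto)
    parsed['image'] = tf.io.decode_jpeg(parsed['image_raw'])
    return parsed

decoded_dataset = tf.data.TFRecordDataset(['images.tfrecords']).map(_parse_and_decode)
'''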

#####################################################################################################
# Read the TFRecord file, with the file paths supplied through a placeholder
input_files = tf.placeholder(tf.string)
dataset = tf.data.TFRecordDataset(input_files)
dataset = dataset.map(_parse_image_function)
# Define an initializable_iterator() to traverse the dataset
iterator = dataset.make_initializable_iterator()
feature_dict = iterator.get_next()
with tf.Session() as sess:
    sess.run(iterator.initializer, feed_dict={input_files: ['images.tfrecords', 'images.tfrecords']})
    # Iterate over all the data for one epoch; an OutOfRangeError is raised when the end is reached.
    # Because the input files are specified dynamically, the amount of data from each source is unknown
    # in advance, so this pattern avoids having to know the exact dataset size beforehand.
    while True:
        try:
            feature_dict_val = sess.run(feature_dict)
            #print('height: ', feature_dict_val['height'])
            #print('width: ', feature_dict_val['width'])
            #print('depth: ', feature_dict_val['depth'])
            #print('label: ', feature_dict_val['label'])
            img = tf.io.decode_image(feature_dict_val['image_raw']).eval()
            #plt.imshow(img)
            #plt.show()
        except tf.errors.OutOfRangeError:
            break
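
# Because the iterator is initializable and the file list is a placeholder, the same graph can be
# re-run on a different set of TFRecord files just by re-initializing it (a sketch; 'other.tfrecords'
# is a placeholder file name):
'''
sess.run(iterator.initializer, feed_dict={input_files: ['other.tfrecords']})
'''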

# Given an image, randomly adjust its colors. Since the order in which brightness, contrast, saturation
# and hue are adjusted affects the final result, several different orderings can be defined, and the
# ordering to use can be chosen at random when preprocessing the training data.
def distort_color(img, color_ordering=0):
    if color_ordering == 0:
        img = tf.image.random_brightness(img, max_delta=32./255.)
        img = tf.image.random_saturation(img, lower=0.5, upper=1.5)
        img = tf.image.random_hue(img, max_delta=0.2)
        img = tf.image.random_contrast(img, lower=0.5, upper=1.5)
    elif color_ordering == 1:
        img = tf.image.random_contrast(img, lower=0.5, upper=1.5)
        img = tf.image.random_saturation(img, lower=0.5, upper=1.5)
        img = tf.image.random_hue(img, max_delta=0.2)
        img = tf.image.random_brightness(img, max_delta=32./255.)
    # many more orderings could be defined here ...
    #elif color_ordering == 2:
    return tf.clip_by_value(img, 0.0, 1.0)  # the color ops can push values outside [0.0, 1.0], so clip back into range

# Given a parsed example proto, a target size and the image's bounding boxes, this function performs the
# preprocessing: its input is a raw training image from an image-recognition task and its output feeds the
# input layer of the neural network. Only training data is processed this way; random distortions are not
# applied to data used for prediction.
# feature_dict has the return type of tf.io.parse_single_example(): 'A dict mapping feature keys to Tensor and SparseTensor values.'
# https://tensorflow.google.cn/versions/r1.15/api_docs/python/tf/io/parse_single_example
def preprocess_for_train(feature_dict, resized_h, resized_w, boundbox):
    # original image metadata
    img_h = feature_dict['height']
    img_w = feature_dict['width']
    img_ch = feature_dict['depth']
    label = feature_dict['label']
    img_raw = feature_dict['image_raw']

    # decode the image
    img_data = tf.io.decode_jpeg(img_raw)
    if img_data.dtype != tf.float32:
        img_data = tf.image.convert_image_dtype(img_data, dtype=tf.float32)

    # if no bounding box is given, treat the whole image as the region of interest
    if boundbox is None:
        boundbox = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4])

    begin, size, bboxes = tf.image.sample_distorted_bounding_box(
        image_size=[img_h, img_w, img_ch], bounding_boxes=boundbox, min_object_covered=0.5)
    distorted_img = tf.slice(img_data, begin, size)
    # resize the randomly cropped image to the size of the network's input layer; the resize method is picked at random
    distorted_img = tf.image.resize_images(distorted_img, [resized_h, resized_w], method=np.random.randint(4))
    # randomly flip the image left-right
    distorted_img = tf.image.random_flip_left_right(distorted_img)
    # randomly distort the image colors (note: np.random.randint here runs once, at graph-construction
    # time, so one ordering and one resize method are fixed for the whole run; see the sketch after this function)
    distorted_img = distort_color(distorted_img, np.random.randint(2))

    return distorted_img, label
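
# Note: np.random.randint() inside preprocess_for_train() is evaluated once, when the graph is built,
# so the same color ordering and resize method are reused for every image in a run. To pick a different
# ordering per image inside the graph, one option (a sketch) is an in-graph random integer with tf.case:
'''
def distort_color_random(img):
    ordering = tf.random_uniform([], minval=0, maxval=2, dtype=tf.int32)
    return tf.case({tf.equal(ordering, 0): lambda: distort_color(img, 0),
                    tf.equal(ordering, 1): lambda: distort_color(img, 1)},
                   exclusive=True)
'''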


#####################################################################################################
PREDEFINED_HEIGHT = 120
PREDEFINED_WIDTH = 120
BATCH_SIZE = 9          # batch size
SHUFFLE_BUFFER = 10000  # shuffle buffer size; larger gives better shuffling but uses more memory
REPEAT_NUM = 5          # number of times the dataset is repeated, i.e. the number of epochs;
                        # e.g. 10 original samples become 50 after repeat
input_files = ['images.tfrecords']  # multiple files are allowed
dataset = tf.data.TFRecordDataset(input_files)
dataset = dataset.map(_parse_image_function)
# map() here applies preprocess_for_train to every element of the dataset. At this point each element has the
# return type of tf.io.parse_single_example(), i.e. <class 'dict'>; in other words, the element type of the
# dataset changes with every map operation. After this map, each element has the return type of
# preprocess_for_train(), i.e. the types of (distorted_img, label).
dataset = dataset.map(lambda feat_dict: preprocess_for_train(feat_dict, PREDEFINED_HEIGHT, PREDEFINED_WIDTH, None))
dataset = dataset.shuffle(SHUFFLE_BUFFER).batch(BATCH_SIZE)
dataset = dataset.repeat(REPEAT_NUM)
iterator = dataset.make_one_shot_iterator()
img, label = iterator.get_next()   # img and label each hold a batch of BATCH_SIZE elements
with tf.Session() as sess:
    while True:
        try:
            # the 3x3 grid below assumes BATCH_SIZE == 9
            fig = plt.figure()
            axes = [fig.add_subplot(3, 3, i + 1) for i in range(9)]
            # Note: if the number of original samples in the dataset (before repeat) is not a multiple of
            # BATCH_SIZE, the last batch of each pass is smaller than BATCH_SIZE. That is harmless for
            # training, but the printing and plotting below will then raise an IndexError.
            img_val, label_val = sess.run([img, label])
            print('label: ', *(label_val[i] for i in range(9)))
            for i in range(9):
                axes[i].imshow(img_val[i])
            plt.show()
        except tf.errors.OutOfRangeError:
            break
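
# To avoid the partial final batch described above, recent TensorFlow versions let batch() drop it
# (a sketch of the alternative batching line):
'''
dataset = dataset.shuffle(SHUFFLE_BUFFER).batch(BATCH_SIZE, drop_remainder=True)
'''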