import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
# The following functions can be used to convert a value to a type compatible with tf.Example.
def _bytes_feature(value):
    """Return a ``tf.train.Feature`` holding a single bytes/str ``value`` in a BytesList.

    An eager tensor is first converted to its numpy value, because BytesList
    won't unpack a string from an EagerTensor.
    """
    # Checking against type(tf.constant(0)) would create a new Const op in the default
    # graph on every call when running in graph mode (tf.Session, as this script does).
    # Testing the tensor type and the execution mode directly creates no ops.
    if isinstance(value, tf.Tensor) and tf.executing_eagerly():
        value = value.numpy()
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
def _float_feature(value):
    """Wrap a single float (float32 or float64) scalar in a FloatList ``tf.train.Feature``."""
    float_list = tf.train.FloatList(value=[value])
    return tf.train.Feature(float_list=float_list)
def _int64_feature(value):
    """Wrap a single bool / enum / (u)int32 / (u)int64 scalar in an Int64List ``tf.train.Feature``."""
    int64_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_list)
# Build a tf.train.Example with features that may be relevant.
# With newer TensorFlow versions the shape of the decoded tensor can be read
# directly, so no session is needed here.
def image_example(img_raw, label):
    """Wrap one encoded JPEG image and its label in a ``tf.train.Example``.

    The image is decoded only to read its height, width and depth; the original
    encoded bytes are what get stored under 'image_raw'.

    Args:
        img_raw: encoded JPEG bytes.
        label: integer class label.

    Returns:
        A ``tf.train.Example`` with 'height', 'width', 'depth', 'label' and
        'image_raw' features.
    """
    dims = tf.image.decode_jpeg(img_raw).shape
    feature = {
        name: _int64_feature(dims[axis])
        for axis, name in enumerate(('height', 'width', 'depth'))
    }
    feature['label'] = _int64_feature(label)
    feature['image_raw'] = _bytes_feature(img_raw)
    return tf.train.Example(features=tf.train.Features(feature=feature))
# Build a tf.train.Example with features that may be relevant.
# Older TensorFlow versions cannot read the shape directly from the decoded tensor,
# so the image is decoded by running the graph and the shape is taken from the
# resulting numpy.ndarray.
def image_example_sess(img_raw, label, sess):
    """Build a ``tf.train.Example`` for one encoded JPEG image using a TF1 session.

    Args:
        img_raw: encoded JPEG bytes; stored unchanged under 'image_raw'.
        label: integer class label stored under 'label'.
        sess: ``tf.Session`` used to decode the image to obtain its
            height / width / depth.

    Returns:
        A ``tf.train.Example`` with 'height', 'width', 'depth', 'label' and
        'image_raw' features.
    """
    # Build the placeholder -> decode_jpeg pair only once per graph and keep it in a
    # graph collection. Creating a fresh decode op on every call would add new nodes to
    # the graph for every image, so the graph would keep growing over a large dataset.
    collection = '_image_example_sess_decode_jpeg'
    graph = sess.graph
    cached = graph.get_collection(collection)
    if not cached:
        with graph.as_default():
            raw_input = tf.placeholder(tf.string, shape=[])
            decoded = tf.image.decode_jpeg(raw_input)
        graph.add_to_collection(collection, raw_input)
        graph.add_to_collection(collection, decoded)
        cached = [raw_input, decoded]
    raw_input, decoded = cached
    image_shape = sess.run(decoded, feed_dict={raw_input: img_raw}).shape
    feature = {
        'height': _int64_feature(image_shape[0]),
        'width': _int64_feature(image_shape[1]),
        'depth': _int64_feature(image_shape[2]),
        'label': _int64_feature(label),
        'image_raw': _bytes_feature(img_raw),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature))
# Example: build a single tf.Example in graph mode.
#     with tf.io.gfile.GFile('test1.jpg', 'rb') as f:
#         img_raw = f.read()
#     with tf.Session() as sess:
#         print(image_example_sess(img_raw, 0, sess))

#####################################################################################################
# Write the TFRecord file.
# Write the raw image files to `images.tfrecords`:
# first, process each listed image into a `tf.Example` message;
# then, write it to the `.tfrecords` file.
image_labels = {
    '1.jpg': 0, '2.jpg': 0,
    '3.jpg': 1, '4.jpg': 1,
    '5.jpg': 2, '6.jpg': 2,
    '7.jpg': 3, '8.jpg': 3,
    '9.jpg': 4, '10.jpg': 4,
}
record_file = 'images.tfrecords'
with tf.Session() as sess, tf.io.TFRecordWriter(record_file) as writer:
    for filename, label in image_labels.items():
        # A `with` block closes each image file promptly; the deprecated
        # FastGFile(...).read() left every handle open.
        with tf.io.gfile.GFile(filename, 'rb') as image_file:
            img_raw = image_file.read()
        tf_example = image_example_sess(img_raw, label, sess)
        writer.write(tf_example.SerializeToString())
def _parse_image_function(example_proto):
    """Parse one serialized ``tf.train.Example`` in the image/label format.

    Expects scalar int64 features 'height', 'width', 'depth', 'label' and a
    scalar string feature 'image_raw' (the format produced by image_example /
    image_example_sess).

    Returns:
        A dict mapping those feature keys to their parsed tensors.
    """
    int_scalar = tf.io.FixedLenFeature([], tf.int64)
    schema = {key: int_scalar for key in ('height', 'width', 'depth', 'label')}
    schema['image_raw'] = tf.io.FixedLenFeature([], tf.string)
    return tf.io.parse_single_example(example_proto, schema)
#####################################################################################################
# Read the TFRecord file with a one-shot iterator.
input_files = ['images.tfrecords']  # more than one file may be listed
raw_image_dataset = tf.data.TFRecordDataset(input_files)
parsed_image_dataset = raw_image_dataset.map(_parse_image_function)
iterator = parsed_image_dataset.make_one_shot_iterator()
feature_dict = iterator.get_next()
# Build the decode op once, outside the loop. Calling tf.io.decode_image(...).eval()
# per iteration would add a new op to the graph for every record.
decoded_img = tf.io.decode_image(feature_dict['image_raw'])
with tf.Session() as sess:
    for _ in range(len(image_labels)):
        # A single run pulls one record and decodes that same record.
        feature_dict_val, img = sess.run([feature_dict, decoded_img])
        #print('height: ', feature_dict_val['height'])
        #print('width: ', feature_dict_val['width'])
        #print('depth: ', feature_dict_val['depth'])
        #print('label: ', feature_dict_val['label'])
        #plt.imshow(img)
        #plt.show()
#####################################################################################################
# Read TFRecord files whose paths are supplied at run time through a placeholder.
input_files = tf.placeholder(tf.string)
dataset = tf.data.TFRecordDataset(input_files)
dataset = dataset.map(_parse_image_function)
# An initializable iterator is used because the dataset depends on a placeholder
# that has to be fed when the iterator is initialized.
iterator = dataset.make_initializable_iterator()
feature_dict = iterator.get_next()
# Build the decode op once, outside the loop, so the graph does not grow per record.
decoded_img = tf.io.decode_image(feature_dict['image_raw'])
with tf.Session() as sess:
    sess.run(iterator.initializer, feed_dict={input_files: ['images.tfrecords', 'images.tfrecords']})
    # Iterate over all the data once; OutOfRangeError is raised at the end. When the
    # inputs are chosen dynamically, the number of records is not known in advance,
    # so this avoids needing the exact count.
    while True:
        try:
            feature_dict_val, img = sess.run([feature_dict, decoded_img])
        except tf.errors.OutOfRangeError:
            break
        #print('height: ', feature_dict_val['height'])
        #print('width: ', feature_dict_val['width'])
        #print('depth: ', feature_dict_val['depth'])
        #print('label: ', feature_dict_val['label'])
        #plt.imshow(img)
        #plt.show()
# Randomly adjust an image's colours. Brightness, contrast, saturation and hue
# adjustments applied in different orders give different results, so several orderings
# are defined; which one to use can be chosen at random when preprocessing training data.
def distort_color(img, color_ordering=0):
    """Apply random brightness / saturation / hue / contrast jitter in a chosen order.

    Args:
        img: image tensor; the result is clipped to [0.0, 1.0], so a float image
            in that range is presumably expected (TODO confirm for other callers).
        color_ordering: 0 -> brightness, saturation, hue, contrast;
            1 -> contrast, saturation, hue, brightness.
            Any other value applies no jitter and only clips.

    Returns:
        The adjusted image clipped to [0.0, 1.0].
    """
    brightness = (tf.image.random_brightness, {'max_delta': 32. / 255.})
    saturation = (tf.image.random_saturation, {'lower': 0.5, 'upper': 1.5})
    hue = (tf.image.random_hue, {'max_delta': 0.2})
    contrast = (tf.image.random_contrast, {'lower': 0.5, 'upper': 1.5})
    # More orderings can be added here.
    orderings = (
        (0, (brightness, saturation, hue, contrast)),
        (1, (contrast, saturation, hue, brightness)),
    )
    steps = next((seq for key, seq in orderings if key == color_ordering), ())
    for adjust, kwargs in steps:
        img = adjust(img, **kwargs)
    # The random adjustments can push values outside [0.0, 1.0]; clip them back.
    return tf.clip_by_value(img, 0.0, 1.0)
def _apply_random_branch(x, fn, num_cases):
    """Return fn(x, case) for a case index drawn uniformly from [0, num_cases) at run time.

    Python-side randomness (e.g. np.random) inside a function traced by Dataset.map
    is evaluated only once, when the graph is built, so every element would take the
    same branch. Drawing the index with a TF op and dispatching with tf.case picks a
    branch independently for each element.
    """
    if num_cases == 1:
        return fn(x, 0)
    case = tf.random.uniform([], maxval=num_cases, dtype=tf.int32)
    # `i=i` binds the loop value now; tf.case calls each branch with no arguments.
    branches = [(tf.equal(case, i), lambda i=i: fn(x, i)) for i in range(num_cases - 1)]
    return tf.case(branches, default=lambda: fn(x, num_cases - 1), exclusive=True)


# Given one parsed example (the dict returned by tf.io.parse_single_example), a target
# size and optional bounding boxes, produce a randomly augmented training image for the
# network's input layer. Only training data should go through this; prediction data
# should not be randomly transformed.
# feature_dict: 'A dict mapping feature keys to Tensor and SparseTensor values.'
# https://tensorflow.google.cn/versions/r1.15/api_docs/python/tf/io/parse_single_example
def preprocess_for_train(feature_dict, resized_h, resized_w, boundbox=None):
    """Randomly crop, resize, flip and colour-jitter one training image.

    Args:
        feature_dict: parsed features; must contain 'image_raw' (encoded JPEG
            string tensor) and 'label'.
        resized_h: output image height.
        resized_w: output image width.
        boundbox: float32 bounding-box tensor passed to
            tf.image.sample_distorted_bounding_box (shape [1, N, 4] in the
            default used here); None means the whole image.

    Returns:
        (distorted_img, label): a float32 image of size resized_h x resized_w with
        3 channels and values clipped to [0, 1], and feature_dict['label'] unchanged.
    """
    label = feature_dict['label']
    # channels=3 forces RGB: the saturation/hue ops in distort_color need three
    # channels, and batching needs every image to have the same depth.
    img_data = tf.io.decode_jpeg(feature_dict['image_raw'], channels=3)
    # uint8 [0, 255] -> float32 [0, 1], the range distort_color clips to.
    img_data = tf.image.convert_image_dtype(img_data, dtype=tf.float32)
    if boundbox is None:
        # No annotation: treat the whole image as the region of interest.
        boundbox = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4])
    # Take the size from the decoded tensor itself so it always matches what is sliced.
    begin, size, _ = tf.image.sample_distorted_bounding_box(
        image_size=tf.shape(img_data),
        bounding_boxes=boundbox,
        min_object_covered=0.5)
    distorted_img = tf.slice(img_data, begin, size)
    # Resize the random crop to the network input size with a per-element random
    # resize method (TF1 ResizeMethod values 0..3).
    distorted_img = _apply_random_branch(
        distorted_img,
        lambda x, method: tf.image.resize_images(x, [resized_h, resized_w], method=method),
        4)
    # Randomly flip left/right.
    distorted_img = tf.image.random_flip_left_right(distorted_img)
    # Colour jitter with a per-element random ordering (0 or 1).
    distorted_img = _apply_random_branch(distorted_img, distort_color, 2)
    return distorted_img, label
#####################################################################################################
PREDIFINE_HEIGHT = 120
PREDIFINE_WIDTH = 120
BATCH_SIZE = 9  # number of examples per batch
SHUFFLE_BUFFER = 10000  # shuffle buffer size: larger shuffles better but uses more memory
REPEAT_NUM = 5  # passes over the data (epochs); e.g. 10 records become 50 after repeat
input_files = ['images.tfrecords']  # more than one file may be listed
dataset = tf.data.TFRecordDataset(input_files)
dataset = dataset.map(_parse_image_function)
# Apply preprocess_for_train to every element. Each element is the dict returned by
# tf.io.parse_single_example, so this map() changes the element type to the
# (distorted_img, label) pair returned by preprocess_for_train.
dataset = dataset.map(
    lambda feat_dict: preprocess_for_train(feat_dict, PREDIFINE_HEIGHT, PREDIFINE_WIDTH, None))
dataset = dataset.shuffle(SHUFFLE_BUFFER).batch(BATCH_SIZE)
dataset = dataset.repeat(REPEAT_NUM)
iterator = dataset.make_one_shot_iterator()
img, label = iterator.get_next()  # each holds up to BATCH_SIZE examples
# Near-square grid large enough for a full batch (3 x 3 for BATCH_SIZE = 9).
grid_cols = int(np.ceil(np.sqrt(BATCH_SIZE)))
grid_rows = (BATCH_SIZE + grid_cols - 1) // grid_cols
with tf.Session() as sess:
    while True:
        try:
            img_val, label_val = sess.run([img, label])
        except tf.errors.OutOfRangeError:
            break
        # Batching happens before repeat, so when the record count is not a multiple of
        # BATCH_SIZE (10 records, batches of 9) the last batch of each pass is smaller.
        # Only print and draw the examples actually returned instead of indexing 0..8.
        print('label: ', *label_val)
        fig = plt.figure()
        for idx, image in enumerate(img_val):
            fig.add_subplot(grid_rows, grid_cols, idx + 1).imshow(image)
        plt.show()
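# 如何使用TensorFlow进行训练数据的准备? (How do you prepare training data with TensorFlow?)
# 最新推荐文章于 2024-01-21 13:54:38 发布 (web-page residue: "latest recommended article published 2024-01-21 13:54:38")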