.tfrecords训练文件的生成
import os
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import cv2
def extract_image(filename, resize_height, resize_width):
    """Load an image file, resize it and return it in RGB channel order.

    Args:
        filename: Path of the image file to read with OpenCV.
        resize_height: Number of rows of the returned image.
        resize_width: Number of columns of the returned image.

    Returns:
        A 3-channel array of shape ``(resize_height, resize_width, 3)`` in
        RGB order (OpenCV's default ``imread`` mode loads 3-channel BGR).

    Raises:
        ValueError: If OpenCV cannot read ``filename``.
    """
    image = cv2.imread(filename)
    # cv2.imread signals failure by returning None instead of raising.
    if image is None:
        raise ValueError('Failed to read image: ' + filename)
    # cv2.resize expects dsize as (width, height), not (height, width).
    image = cv2.resize(image, (resize_width, resize_height))
    # OpenCV loads BGR; convert so downstream consumers (e.g. PIL) see RGB.
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
if __name__ == '__main__':
    # Build a TFRecord file of (label, raw RGB bytes) examples.
    # Each sub-directory of `cwd` is one class.
    cwd = '/media/digta/tfrecord/test/'
    # A class's position in this list is the integer label stored in its records.
    classes = ['0000045', '0000099']
    # The context manager closes (and flushes) the writer even if an image fails to load.
    with tf.python_io.TFRecordWriter("/media/digta/tfrecord/test.tfrecords") as writer:
        for index, name in enumerate(classes):
            class_path = os.path.join(cwd, name)
            # sorted() makes the record order reproducible across runs and filesystems.
            for img_name in sorted(os.listdir(class_path)):
                img_path = os.path.join(class_path, img_name)
                image = extract_image(img_path, 128, 128)
                # Raw pixel bytes; read_batch decodes them as uint8 and reshapes to [128, 128, 3].
                img_raw = image.tobytes()
                example = tf.train.Example(features=tf.train.Features(feature={
                    "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
                    'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
                }))
                writer.write(example.SerializeToString())
读取生成的.tfrecords文件
import tensorflow as tf
import numpy as np
from PIL import Image
import time
import cv2
def read_batch(record_path, batch_size, image_shape=(128, 128, 3),
               num_threads=64, min_after_dequeue=1000):
    """Build a TF1 queue-based pipeline yielding shuffled batches from a TFRecord file.

    Each record is parsed as an int64 ``label`` feature and a string
    ``img_raw`` feature whose bytes are decoded as uint8 pixels.

    Args:
        record_path: Path of the ``.tfrecords`` file.
        batch_size: Number of examples per batch.
        image_shape: Shape the decoded pixel bytes are reshaped to; must match
            what the writer stored. Defaults to ``(128, 128, 3)``.
        num_threads: Number of threads enqueuing into the shuffle queue.
        min_after_dequeue: Minimum number of examples kept in the queue after
            a dequeue; larger values give better shuffling.

    Returns:
        Tuple ``(images, labels, mean_imgs)`` of batched tensors: the raw
        uint8 images, int32 labels, and per-image standardized images.

    Raises:
        ValueError: If ``record_path`` does not exist.
    """
    if not tf.gfile.Exists(record_path):
        raise ValueError('Failed to find file: ' + record_path)
    filename_queue = tf.train.string_input_producer([record_path])
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(
        serialized_example,
        features={
            'label': tf.FixedLenFeature([], tf.int64),
            'img_raw': tf.FixedLenFeature([], tf.string),
        })
    image = tf.reshape(tf.decode_raw(features['img_raw'], tf.uint8), list(image_shape))
    label = tf.cast(features['label'], tf.int32)
    # Despite the name, this is per-image zero-mean / unit-variance
    # standardization (a float tensor), not an averaged "mean image".
    mean_img = tf.image.per_image_standardization(image)
    images, labels, mean_imgs = tf.train.shuffle_batch(
        [image, label, mean_img],
        batch_size=batch_size,
        num_threads=num_threads,
        # Room for min_after_dequeue plus a few batches, the original sizing rule.
        capacity=min_after_dequeue + 3 * batch_size,
        min_after_dequeue=min_after_dequeue)
    return images, labels, mean_imgs
if __name__ == '__main__':
    # Pull one shuffled batch from the TFRecord file and display each
    # standardized image.
    # NOTE(review): this path now matches the file written by the generator
    # script above; the original pointed at .../test/test.tfrecords.
    tf_record_path = '/media/digta/tfrecord/test.tfrecords'
    batch_size = 32
    raw_images, labels, mean_imgs = read_batch(tf_record_path, batch_size)
    with tf.Session() as sess:
        # global_variables_initializer replaces the deprecated initialize_all_variables.
        sess.run(tf.global_variables_initializer())
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            _raw_batch, _label_batch, std_batch = sess.run([raw_images, labels, mean_imgs])
            # Iterate the actual batch rather than a hard-coded count.
            for std_img in std_batch:
                # Standardized images are floats with an arbitrary range; PIL's
                # 'RGB' mode needs uint8, so min-max rescale to 0..255 first.
                lo, hi = std_img.min(), std_img.max()
                scaled = (std_img - lo) / max(hi - lo, 1e-8) * 255.0
                img = Image.fromarray(scaled.astype(np.uint8), 'RGB')
                img.show()
                time.sleep(0.5)
        finally:
            # Always stop and join the queue-runner threads, even if display fails.
            coord.request_stop()
            coord.join(threads)