import tensorflow as tf
import numpy
def write_binary(filename='data.tfrecord', num_examples=100):
    """Write a small synthetic dataset of ``tf.train.Example`` records.

    Record ``i`` (0 <= i < num_examples) holds three features:

    * ``'a'``: one float, ``0.618 + i``;
    * ``'b'``: two int64 values, ``[2016 + i, 2017 + i]``;
    * ``'c'``: the raw bytes of the 2x3 uint8 array
      ``[[0, 1, 2], [3, 4, 5]] + i`` (values above 255 wrap on the uint8 cast).

    Args:
        filename: Path of the TFRecord file to write.
        num_examples: Number of records to write.
    """
    base = numpy.array([[0, 1, 2], [3, 4, 5]])
    # The context manager closes the writer even if building a record fails.
    with tf.python_io.TFRecordWriter(filename) as writer:
        for i in range(num_examples):
            a = 0.618 + i
            b = [2016 + i, 2017 + i]
            # tobytes() is the non-deprecated spelling of ndarray.tostring().
            c_raw = (base + i).astype(numpy.uint8).tobytes()
            # Each Example's `features` is a dict holding the different parts
            # of one sample (e.g. image pixels + class label).
            example = tf.train.Example(
                features=tf.train.Features(
                    feature={
                        'a': tf.train.Feature(
                            float_list=tf.train.FloatList(value=[a])),
                        'b': tf.train.Feature(
                            int64_list=tf.train.Int64List(value=b)),
                        'c': tf.train.Feature(
                            bytes_list=tf.train.BytesList(value=[c_raw])),
                    }
                )
            )
            # Serialize and append the record to the file.
            writer.write(example.SerializeToString())
def read_single_sample(filename):
    """Build graph ops that read and decode one Example from a TFRecord file.

    Each record is expected to carry a float ``'a'``, two int64 values ``'b'``,
    and a bytes feature ``'c'`` holding six uint8 values.

    Args:
        filename: Path of the TFRecord file to read.

    Returns:
        Tensors ``(a, b, c)``: a float32 scalar, an int64 vector of length 2,
        and a uint8 tensor reshaped to ``[2, 3]``.
    """
    # File-name queue; num_epochs=None places no limit on how many reads occur.
    file_queue = tf.train.string_input_producer([filename], num_epochs=None)
    record_reader = tf.TFRecordReader()
    # Pull one serialized record from the queue (the record key is unused).
    _, record = record_reader.read(file_queue)
    feature_spec = {
        'a': tf.FixedLenFeature([], tf.float32),
        'b': tf.FixedLenFeature([2], tf.int64),
        'c': tf.FixedLenFeature([], tf.string),
    }
    # Parse the serialized record into its named features.
    parsed = tf.parse_single_example(record, features=feature_spec)
    # 'c' is a raw byte string; decode it as uint8 and restore the 2x3 layout.
    matrix = tf.reshape(tf.decode_raw(parsed['c'], tf.uint8), [2, 3])
    return parsed['a'], parsed['b'], matrix
# Read back the synthetic dataset. data.tfrecord must already exist; call
# write_binary() once beforehand to create it.
a, b, c = read_single_sample('data.tfrecord')
# Shuffled batches of 5 examples.
a_batch, b_batch, c_batch = tf.train.shuffle_batch(
    [a, b, c], batch_size=5, capacity=200, min_after_dequeue=100,
    num_threads=2)

sess = tf.Session()
# global_variables_initializer replaces the deprecated initialize_all_variables.
init = tf.global_variables_initializer()
sess.run(init)
# Register the queue-runner threads with a Coordinator so they can be stopped
# and joined, and the session closed, even if a run below fails.
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
try:
    for _ in range(3):
        a_val, b_val, c_val = sess.run([a_batch, b_batch, c_batch])
        print(a_val, b_val, c_val)
finally:
    coord.request_stop()
    coord.join(threads)
    sess.close()
# Practical example (实战代码):
import tensorflow as tf
import numpy
import scipy.misc as misc
import os
import cv2
def write_binary(filename='data.tfrecord', classes=('ym', 'zly', 'lyf')):
    """Pack per-class image folders into a TFRecord file.

    For each name in ``classes``, every decodable image in ``<cwd>/<name>/`` is
    loaded as 3-channel RGB, resized to 250x250 and stored as raw uint8 bytes
    under ``'img_raw'``. The position of the class name in ``classes`` is stored
    as the int64 ``'label'``. Files OpenCV cannot decode are skipped.

    Args:
        filename: Path of the TFRecord file to write.
        classes: Sub-directory names of the current working directory, one per
            class; the index of each name becomes its label.
    """
    cwd = os.getcwd()
    # The context manager closes the writer even if an image fails to encode.
    with tf.python_io.TFRecordWriter(filename) as writer:
        for index, name in enumerate(classes):
            class_path = os.path.join(cwd, name)
            # Sorted so the record order does not depend on the filesystem.
            for img_name in sorted(os.listdir(class_path)):
                img_path = os.path.join(class_path, img_name)
                # scipy.misc.imread/imresize no longer exist in SciPy; use
                # OpenCV. IMREAD_COLOR always yields 3 channels, so grayscale
                # or RGBA files still produce 250*250*3 bytes.
                img = cv2.imread(img_path, cv2.IMREAD_COLOR)
                if img is None:
                    # Not a decodable image (e.g. a stray metadata file).
                    print('skipping unreadable file: %s' % img_path)
                    continue
                # OpenCV loads BGR; store RGB, as scipy.misc.imread produced.
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                img = cv2.resize(img, (250, 250))
                example = tf.train.Example(features=tf.train.Features(feature={
                    'img_raw': tf.train.Feature(
                        bytes_list=tf.train.BytesList(value=[img.tobytes()])),
                    'label': tf.train.Feature(
                        int64_list=tf.train.Int64List(value=[index])),
                }))
                # Serialize and append the record to the file.
                writer.write(example.SerializeToString())
def read_and_decode(filename):
    """Build graph ops that read one labelled image from a TFRecord file.

    Args:
        filename: Path of a TFRecord file whose records carry a bytes feature
            ``'img_raw'`` (250*250*3 uint8 values) and an int64 ``'label'``.

    Returns:
        ``(img, label)``: a ``[250, 250, 3]`` uint8 tensor and an int32 scalar
        tensor.
    """
    # File-name queue; shuffle=False keeps file names in the given order.
    queue = tf.train.string_input_producer([filename], shuffle=False)
    # Pull one serialized record (the record key is unused).
    _, record = tf.TFRecordReader().read(queue)
    parsed = tf.parse_single_example(
        record,
        features={
            'label': tf.FixedLenFeature([], tf.int64),
            'img_raw': tf.FixedLenFeature([], tf.string),
        },
    )
    # Decode the raw byte string and restore the image layout.
    flat_pixels = tf.decode_raw(parsed['img_raw'], tf.uint8)
    img = tf.reshape(flat_pixels, [250, 250, 3])
    label = tf.cast(parsed['label'], tf.int32)
    return img, label
# Read back the image dataset. data.tfrecord must already exist; call
# write_binary() once beforehand to create it.
batch_size = 18
img, label = read_and_decode('data.tfrecord')
img_batch, label_batch = tf.train.shuffle_batch(
    [img, label], batch_size=batch_size, capacity=200, min_after_dequeue=100,
    num_threads=2)

init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)
# Register the queue-runner threads with a Coordinator so they are stopped
# and joined, and the session closed, even if a run below fails.
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
try:
    # Fetched arrays get their own names instead of shadowing the tensors.
    images, labels = sess.run([img_batch, label_batch])
    for i, image in enumerate(images):
        # Swap channels 0 and 2 (RGB -> BGR) since cv2.imwrite expects BGR.
        cv2.imwrite('%d.png' % i, cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
finally:
    coord.request_stop()
    coord.join(threads)
    sess.close()