# 1. gen_data.py — generate the train.tfrecords data (生成train.tfrecords的数据)
import os
import tensorflow as tf
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
# Root of the CycleGAN apple2orange dataset: testA = apple images,
# testB = orange images.
path = r"D:\Deep_Learning_data\cyclegan\apple2orange"
# A tuple, not a set: iteration order must be fixed so that testA always gets
# label 0 and testB label 1. A set of strings iterates in hash order, which
# can differ between interpreter runs and would silently swap the labels.
classes = ('testA', 'testB')
# Write one tf.train.Example per image into a single TFRecord file. Each
# example has an int64 'label' (the class index) and 'img_raw' (pixel bytes).
with tf.python_io.TFRecordWriter('train.tfrecords') as writer:  # output file
    for index, name in enumerate(classes):
        class_path = os.path.join(path, name)
        # Sorted so the record order is reproducible across runs/filesystems.
        for img_name in sorted(os.listdir(class_path)):
            img_path = os.path.join(class_path, img_name)
            with Image.open(img_path) as src:
                # Force 3-channel RGB at 256x256 so every record holds exactly
                # 256*256*3 bytes, matching the [256, 256, 3] reshape on read.
                img = src.convert('RGB').resize((256, 256))
            img_raw = img.tobytes()  # raw pixel bytes of the image
            example = tf.train.Example(features=tf.train.Features(feature={
                'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
                'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])),
            }))
            writer.write(example.SerializeToString())
def read_and_decode(filename):
    """Build a TF1 queue-based pipeline that decodes records from a TFRecord file.

    Each record is parsed as a tf.train.Example with an int64 'label' and an
    'img_raw' byte string. The image bytes are decoded as uint8 and reshaped
    to [256, 256, 3], so each record is expected to hold exactly 256*256*3
    bytes.

    Args:
        filename: path to the .tfrecords file.

    Returns:
        A tuple (img, label). img is a float32 tensor of shape [256, 256, 3].
        label is an int32 scalar tensor. Producing values requires the
        queue runners to be started inside a session.
    """
    filename_queue = tf.train.string_input_producer([filename])
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    # Pull the image bytes and the label out of the serialized Example.
    features = tf.parse_single_example(
        serialized_example,
        features={
            'label': tf.FixedLenFeature([], tf.int64),
            'img_raw': tf.FixedLenFeature([], tf.string),
        })
    img = tf.decode_raw(features['img_raw'], tf.uint8)
    img = tf.reshape(img, [256, 256, 3])
    img = tf.cast(img, tf.float32)
    # Look up the 'label' entry first, then cast it; indexing the dict with
    # ('label', tf.int32) as a tuple key would fail.
    label = tf.cast(features['label'], tf.int32)
    return img, label
# 2. read_data.py — read back the TFRecord-format data (读取tfrecord格式的数据)
import os
import tensorflow as tf
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
# Input pipeline: filename queue -> TFRecord reader -> parsed Example.
file_queue = tf.train.string_input_producer(['train.tfrecords'])
reader = tf.TFRecordReader()
_, serialized_example = reader.read(file_queue)
features = tf.parse_single_example(
    serialized_example,
    features={
        'label': tf.FixedLenFeature([], tf.int64),
        'img_raw': tf.FixedLenFeature([], tf.string),
    })
image = tf.decode_raw(features['img_raw'], tf.uint8)
image = tf.reshape(image, [256, 256, 3])
label = tf.cast(features['label'], tf.int32)

# Fetch 10 records. Save each one as "<i>_Label_<label>.jpg" in the current
# working directory, and print the path plus the image and label shapes.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        for i in range(10):
            print()
            example, lbl = sess.run([image, label])
            img = Image.fromarray(example, 'RGB')
            save_path = os.path.join(os.getcwd(), str(i) + '_Label_' + str(lbl) + '.jpg')
            print(save_path)
            img.save(save_path)
            print(example.shape, lbl.shape)
    finally:
        # string_input_producer cycles through its input indefinitely, so the
        # queue-runner threads never exit on their own. Without request_stop(),
        # coord.join() would block forever.
        coord.request_stop()
        coord.join(threads)