项目详细请猛戳我的github地址,直接可运行:https://github.com/SamXiaosheng/create-tfRecord
下面是main文件代码和create tfRecord文件:
import tensorflow as tf from tfRecord import * import cv2 FLAGS = tf.app.flags.FLAGS tf.app.flags.DEFINE_string('image_dir', './image/', """Directory where to write event logs """) def main(_): create_tfrecords(FLAGS.image_dir) image_batch,label_batch =read_and_decode('test.tfRecord') with tf.Session() as sess: coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(sess=sess, coord=coord) while not coord.should_stop(): image,label = sess.run([image_batch,label_batch]) print(label) cv2.imshow('image',image[0]) cv2.waitKey(200) coord.request_stop() coord.join(threads) if __name__ == '__main__': tf.app.run()
import tensorflow as tf import numpy as np import os import cv2 def read_and_decode(filename): filename_queue = tf.train.string_input_producer([filename]) reader = tf.TFRecordReader() _, serialized_example = reader.read(filename_queue) features = tf.parse_single_example(serialized_example, features={ 'label': tf.FixedLenFeature([], tf.int64), 'img_raw' : tf.FixedLenFeature([], tf.string), }) img = tf.decode_raw(features['img_raw'], tf.uint8)#这里的格式非常重要 img = tf.reshape(img, [227, 227, 3]) #img = tf.cast(img, tf.float32) * (1. / 255) - 0.5 label = tf.cast(features['label'], tf.uint8) image_batch, label_batch = tf.train.shuffle_batch([img, label], batch_size=1,#这里参数设置目的是每次只读取一个样本 capacity=1, min_after_dequeue=0) #label_batch = tf.one_hot(label_batch, NUM_CLASSES) #label_batch = tf.cast(label_batch, dtype=tf.int64) #label_batch = tf.reshape(label_batch, [batch_size, NUM_CLASSES]) return image_batch, label_batch #读取某目录路径下的所有文件,返回图片的名称列表 def dirtomdfbatchmsra(dirpath):#读取目录下训练图像和对应的label image_ext = 'jpg' images = [fn for fn in os.listdir(dirpath) if fn.endswith(image_ext)]#返回dirpath路径下所有后缀jpg文件 images.sort()#排序的目的有利于样本和标签的对应 #print(images) gt_ext = 'png' gt_maps = [fn for fn in os.listdir(dirpath) if fn.endswith(gt_ext)] gt_maps.sort() #print(gt_maps) return gt_maps,images#返回gt图和训练image的所有文件名 def create_tfrecords(image_dir): writer = tf.python_io.TFRecordWriter("test.tfRecord") image_png,image_jpg = dirtomdfbatchmsra(image_dir) for index, name in enumerate(image_jpg): img = cv2.imread(image_dir+name).astype(np.uint8) img = cv2.resize(img,(227,227))#统一大小 img_raw = img.tobytes()#转换成字节形式 example = tf.train.Example(features=tf.train.Features(feature={ "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])), 'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])) })) writer.write(example.SerializeToString()) for index, name in enumerate(image_png): img = cv2.imread(image_dir+name).astype(np.uint8) img = cv2.resize(img,(227, 227))# img_raw = img.tobytes() example = tf.train.Example(features=tf.train.Features(feature={ "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])), 'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])) })) writer.write(example.SerializeToString()) writer.close()