tfrecords文件是TensorFlow规范的数据文件。TensorFlow提供了TFRecord格式来统一存储数据:TFRecord是一种将图像数据和标签放在一起的二进制文件格式,能更好地利用内存,便于在TensorFlow中快速地复制、移动、读取和存储。
TFRecords文件包含了tf.train.Example协议内存块(protocol buffer),协议内存块中包含字段Features。我们可以写一段代码获取数据,将数据填入Example协议内存块,再将协议内存块序列化为一个字符串,并通过tf.python_io.TFRecordWriter写入TFRecords文件。
最近使用到了tfrecord文件,遇见很多坑,记录下来。
这里分为读取txt文件、生成tfrecord、读取tfrecord文件三个部分讲解
import tensorflow as tf
import numpy as np
import glob
import cv2
import os
import matplotlib.pyplot as plt
# Module-level configuration: TensorFlow log level and the dataset paths used
# by the functions below.

# Original author's note: set to '2' to get rid of TensorFlow warnings.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
#shuffle_data = True
# Directory containing the source images (note the trailing slash: file names
# are appended to it by plain string concatenation).
image_path = "E:/project/posCNN_20180724/orignalImg/"
# Intended to collect the paths of all images under image_path as a list.
# NOTE(review): the pattern has no wildcard, so glob presumably matches at most
# the directory itself rather than the files inside it -- confirm whether
# something like image_path + '*.jpg' was meant. addrs is not used anywhere in
# the visible code. How labels are obtained depends on the dataset.
addrs = glob.glob(image_path)
# Text file with one sample per line: image file name followed by three labels.
txt_path = 'E:/project/posCNN_20180724/orignalImg/pos_train.txt'
# Output location of the generated TFRecords file.
train_tfrecord_path = 'E:/project/posCNN_20180724/dataset/train.tfrecords'
1、把txt文件读为列表
我们读取的txt文件格式如下:
第一列是文件名,后面的三列数字是标签(列数根据所需要的标签数量决定),一行一个样本,每次读入一行。
# Read the label txt file.
def readLabelsTxt(path=None):
    """Parse the label file into four parallel lists.

    Every non-blank line must contain exactly four whitespace-separated
    fields: the image file name followed by three label values. Blank lines
    (for example a trailing newline at the end of the file) are skipped.

    Args:
        path: Label file to read. When omitted, the module-level
            ``txt_path`` is used.

    Returns:
        A tuple ``(imgName, valueX, valueY, valueR)`` of lists of strings,
        one entry per sample, in file order. The label values are returned
        exactly as they appear in the file (not converted to numbers).

    Raises:
        ValueError: If a non-blank line does not have exactly four fields.
    """
    if path is None:
        path = txt_path

    imgName = []
    valueX = []
    valueY = []
    valueR = []
    # The ``with`` statement closes the file; no explicit close() is needed.
    with open(path, 'r') as f:
        for line_no, line in enumerate(f, start=1):
            fields = line.split()
            if not fields:
                # Blank line: nothing to record.
                continue
            if len(fields) != 4:
                raise ValueError(
                    'line {} of {}: expected 4 fields, got {}'.format(
                        line_no, path, len(fields)))
            key, value1, value2, value3 = fields
            imgName.append(key)
            valueX.append(value1)
            valueY.append(value2)
            valueR.append(value3)
    print('dataset samples num:', len(imgName))
    return imgName, valueX, valueY, valueR
2、生成tfrecord
def load_image(addr):
    """Load the image stored at *addr* and return it in RGB channel order.

    The image is read with OpenCV (which yields BGR data) and converted to
    RGB. No resizing, normalisation or dtype conversion is performed here.

    Args:
        addr: Path of the image file to read.

    Returns:
        The image array produced by ``cv2.cvtColor(..., cv2.COLOR_BGR2RGB)``.

    Raises:
        IOError: If OpenCV could not read or decode the file.
    """
    img = cv2.imread(addr)
    # cv2.imread signals failure by returning None instead of raising, which
    # would otherwise surface as an opaque error inside cvtColor.
    if img is None:
        raise IOError('Failed to read image: {}'.format(addr))
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
# Helpers that convert a value into the matching tf.train.Feature type.
def _int64_feature(value):
    """Wrap a single value in a ``tf.train.Feature`` backed by an ``Int64List``."""
    int64_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_list)
def _bytes_feature(value):
    """Wrap a single value in a ``tf.train.Feature`` backed by a ``BytesList``."""
    bytes_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=bytes_list)
def _float_feature(value):
    """Wrap a single value in a ``tf.train.Feature`` backed by a ``FloatList``."""
    float_list = tf.train.FloatList(value=[value])
    return tf.train.Feature(float_list=float_list)
# Write the data to a TFRecords file.
def create_tfrecord(imgName, posX, posY, posR):
    """Serialize every image and its three labels into ``train_tfrecord_path``.

    For each entry of *imgName* the image is loaded from
    ``image_path + imgName[i]``, cast to ``uint8`` and stored as raw bytes
    under the key ``'img_raw'``; the labels are stored as int64 features
    under ``'img_posX'``, ``'img_posY'`` and ``'img_posR'``.

    The image dimensions are not written to the record, so a reader has to
    know the shape in order to reshape the decoded bytes.

    Args:
        imgName: Sequence of image file names, relative to ``image_path``.
        posX: Sequence of X labels, indexable in parallel with *imgName*;
            each entry must be convertible with ``int()``.
        posY: Sequence of Y labels, same layout as *posX*.
        posR: Sequence of R labels, same layout as *posX*.
    """
    # Using the writer as a context manager guarantees the file is closed
    # even when loading or converting one of the samples fails.
    with tf.python_io.TFRecordWriter(train_tfrecord_path) as writer:
        for i, name in enumerate(imgName):
            img_path = image_path + name
            print(i,' Path:',img_path)
            img = load_image(img_path).astype(np.uint8)
            # tobytes() yields the same bytes as the deprecated tostring().
            img_raw = img.tobytes()
            example = tf.train.Example(features=tf.train.Features(
                feature={
                    'img_raw': _bytes_feature(img_raw),
                    'img_posX': _int64_feature(int(posX[i])),
                    'img_posY': _int64_feature(int(posY[i])),
                    'img_posR': _int64_feature(int(posR[i])),
                }))
            writer.write(example.SerializeToString())
从对应路径加载图片再和标签一起生成字典
3、读取tfrecord文件
def read_and_decode(is_train, image_shape=(1200, 1600, 3)):
    """Build the queue-based input pipeline that reads ``train_tfrecord_path``.

    One serialized example is read at a time, its ``'img_raw'`` bytes are
    decoded as ``uint8`` and reshaped to *image_shape*, and the three int64
    labels are extracted. The tensors are then grouped into batches of one.

    Args:
        is_train: When true, batches are drawn through
            ``tf.train.shuffle_batch``; otherwise ``tf.train.batch`` is used
            and samples come out in file order.
        image_shape: ``(height, width, channels)`` of the stored images. The
            record itself does not carry the shape, so it must match what was
            written. Defaults to the 1200x1600 RGB images of this dataset.

    Returns:
        A tuple ``(img_raw, labelX, labelY, labelR)`` of batched tensors with
        a leading batch dimension of 1.
    """
    batch_size = 1
    filename_queue = tf.train.string_input_producer(
        [train_tfrecord_path], shuffle=False)
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(
        serialized_example,
        features={
            'img_raw': tf.FixedLenFeature([], tf.string),
            'img_posX': tf.FixedLenFeature([], tf.int64),
            'img_posY': tf.FixedLenFeature([], tf.int64),
            'img_posR': tf.FixedLenFeature([], tf.int64),
        })
    images = tf.decode_raw(features['img_raw'], tf.uint8)
    images = tf.reshape(images, list(image_shape))
    # Each label is read from its own key (X previously read the R value).
    img_posX = features['img_posX']
    img_posY = features['img_posY']
    img_posR = features['img_posR']

    tensors = [images, img_posX, img_posY, img_posR]
    if is_train:
        # The shuffle queue requires capacity to be strictly greater than
        # min_after_dequeue; use the sizing recommended by the TF docs.
        min_after_dequeue = 3
        img_raw, labelX, labelY, labelR = tf.train.shuffle_batch(
            tensors,
            batch_size=batch_size,
            capacity=min_after_dequeue + 3 * batch_size,
            min_after_dequeue=min_after_dequeue)
    else:
        img_raw, labelX, labelY, labelR = tf.train.batch(
            tensors,
            batch_size=batch_size,
            capacity=3)
    return img_raw, labelX, labelY, labelR
最后我们运行程序,可以看到在指定的文件夹内生成了一个.tfrecords文件;再通过函数读取该文件,可以把图片和对应标签显示出来,说明写入和读取的过程是正确的。
if __name__ == '__main__':
    # Script entry point: read one sample back from the TFRecords file, print
    # its labels and display the image.
    #
    # To (re)generate the TFRecords file first, uncomment the two lines below.
    # imgName, posX, posY, posR = readLabelsTxt()
    # create_tfrecord(imgName = imgName, posX = posX, posY = posY, posR = posR)
    image, X, Y, R = read_and_decode(is_train = False)
    init_op = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init_op)
        # Start the queue-runner threads that feed the input pipeline.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            a, b, c, d = sess.run([image, X, Y, R])
            print(b, c, d)
        finally:
            # Always stop and join the reader threads, even if the fetch
            # above raised, so the process does not hang on exit.
            coord.request_stop()
            coord.join(threads)
    # Drop the batch dimension before displaying the image.
    aa = np.uint8(a[0, :, :, :])
    plt.imshow(aa)
    plt.show()
参考:http://www.360doc.com/content/17/0611/21/42392246_661965445.shtml