tensorflow-001 制作、读取tfrecord文件的程序

tfrecords文件是tensorflow规范的数据文件。TensorFlow提供了TFRecord的格式来统一存储数据,TFRecord格式是一种将图像数据和标签放在一起的二进制文件,能更好的利用内存,在tensorflow中快速的复制,移动,读取,存储 等等。 
  TFRecords文件包含了tf.train.Example 协议内存块(protocol buffer)(协议内存块包含了字段 Features)。我们可以写一段代码获取你的数据, 将数据填入到Example协议内存块(protocol buffer),将协议内存块序列化为一个字符串, 并且通过tf.python_io.TFRecordWriter 写入到TFRecords文件。 

    最近使用到了tfrecord文件,遇见很多坑,记录下来。

这里分为读取txt文件、生成tfrecord、读取tfrecord文件三个部分讲解


import tensorflow as tf
import numpy as np
import glob
import cv2
import os
import matplotlib.pyplot as plt

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'#去'warning'
#shuffle_data = True
image_path = "E:/project/posCNN_20180724/orignalImg/"
#取得该路径下所有图片的路径,type(addrs)= list
addrs = glob.glob(image_path) #标签数据的获得具体情况具体分析,type(labels)= list
txt_path = 'E:/project/posCNN_20180724/orignalImg/pos_train.txt'
train_tfrecord_path = 'E:/project/posCNN_20180724/dataset/train.tfrecords'  # 输出文件地址

1、把txt文件读为列表

我们读取的txt文件是下面的这样的:

第一列是文件名,后面的三列数字是标签(这个根据所需要的标签数量觉定),一行一个样本,每次读入一行

#读取标签txt文件
def readLabelsTxt():
    imgName = []
    valueX = []
    valueY = []
    valueR = []
    with open(txt_path, 'r') as f:
        for line in f:
            key, value1, value2, value3 = line.split()      
            imgName.append(key)
            valueX.append(value1)
            valueY.append(value2)
            valueR.append(value3)
#            print(imgCount, imgName[imgCount], valueX[imgCount], valueY[imgCount], valueR[imgCount])
    print('dataset samples num:', len(imgName))
    f.close()
    return imgName, valueX, valueY, valueR

2、生成tfrecord

def load_image(addr):  # A function to Load image
    img = cv2.imread(addr)
#    img = cv2.resize(img, (224, 224), interpolation=cv2.INTER_CUBIC)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
#    img = img / 255. #归一化到[0,1]
#    img = img.astype(np.float32)
    return img   

# 将数据转化成对应的属性
def _int64_feature(value):  
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))     
def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) 
def _float_feature(value):
    return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))

# 把数据写入TFRecods文件
def create_tfrecord(imgName, posX, posY, posR):
    writer = tf.python_io.TFRecordWriter(train_tfrecord_path)# 创建一个writer来写TFRecords文件    
    for i in range(len(imgName)):
        img_path = image_path + imgName[i]
        print(i,' Path:',img_path)
#        img = cv2.imread(img_path)
        img = load_image(img_path)
        img = img.astype(np.uint8)
        img_raw = img.tostring()#img.tobytes()
        example = tf.train.Example(features=tf.train.Features(
                     feature={
                    'img_raw': _bytes_feature(img_raw),
                    'img_posX': _int64_feature(int(posX[i])),
                    'img_posY': _int64_feature(int(posY[i])),
                    'img_posR': _int64_feature(int(posR[i]))
                      }))
        writer.write(example.SerializeToString())
    writer.close()    

从对应路径加载图片再和标签一起生成字典

3、读取tfrecord文件


def read_and_decode(is_train):
#    if FLAGS.IS_TRAIN == True:
#       filename_queue = tf.train.string_input_producer([FLAGS.tfrecord_filename_train], shuffle = True)
#    if FLAGS.IS_TRAIN == False:######num_epochs need to modify 
#       print 'test_dataset'
    filename_queue = tf.train.string_input_producer([train_tfrecord_path], shuffle = False)
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue) 
    features = tf.parse_single_example(serialized_example,
        features={
                'img_raw': tf.FixedLenFeature([], tf.string),
                'img_posX': tf.FixedLenFeature([], tf.int64),
                'img_posY': tf.FixedLenFeature([], tf.int64),
                'img_posR': tf.FixedLenFeature([], tf.int64)       
                })
    
    images = tf.decode_raw(features['img_raw'], tf.uint8)
    images = tf.reshape(images, [1200, 1600, 3])
    img_posX = features['img_posR']
    img_posY = features['img_posY']
    img_posR = features['img_posR']
#    img_posX = tf.cast(features['img_posX'], tf.int32)
#    img_posY = tf.cast(features['img_posY'], tf.int32)
#    img_posR = tf.cast(features['img_posR'], tf.int32)   
    if is_train == True:
       img_raw, labelX, labelY, labelR = tf.train.shuffle_batch([images, img_posX, img_posY, img_posR], 
                                                 batch_size = 1, 
                                                 capacity = 3, 
                                                 min_after_dequeue = 3)
    else:
        img_raw, labelX, labelY, labelR = tf.train.batch([images, img_posX, img_posY, img_posR], 
                                                 batch_size = 1, 
                                                 capacity = 3)
    return img_raw, labelX, labelY, labelR

最后我们运行程序,可以看见在我们规定的文件夹内有了一个。tfrecords文件,通过函数读取文件,可以把图片和对应标签显示出来,说明写入和读取的过程是对的。

if __name__ == '__main__':
#    imgName, posX, posY, posR = readLabelsTxt()
#    create_tfrecord(imgName = imgName, posX = posX, posY = posY, posR = posR)
    image, X, Y, R = read_and_decode(is_train = False)

    init_op = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init_op)
        ## 启动多线程处理输入数据
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        a, b, c, d = sess.run([image, X, Y, R])
        print(b, c, d)
        coord.request_stop()
        coord.join(threads)
        aa = np.uint8(a[0,:, :,:])
        plt.imshow(aa)
        plt.show()

参考:http://www.360doc.com/content/17/0611/21/42392246_661965445.shtml

  • 1
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值