tensorflow下TFRecord的读与写

读:

TFRecordRead:

import tensorflow as tf

#读取train.tfrecord中的数据
def read_and_decode(filename):
    #创建一个reader来读取TFRecord文件中的样例
    reader = tf.TFRecordReader()
    #创建一个队列来维护输入文件列表
    file_queue = tf.train.string_input_producer([filename], shuffle=False, num_epochs=1)
    #从文件读取一个样例,也可以使用read_up_to读取多个样例
    _, serialized_example = reader.read(file_queue)
    #解析读入的样例。如果需要解析多个样例,可以使用parse_example函数
    features = tf.parse_single_example(
        serialized_example,
        features={
            #Tensorflow提供两种不同的属性解析方法:
            # 一种是tf.train.FixedLenFeature,这种方法解析的结果是一个Tensor。
            #另一种是tf.VarLenFeature,这种方法得到的解析结果为SparseTensor,用于处理稀疏数据
            'label': tf.FixedLenFeature([], tf.int64),
            'image_raw': tf.FixedLenFeature([], tf.string)})

    #将字符串解析成图相对应的像素数组
    img = tf.decode_raw(features['image_raw'], tf.uint8)
    img = tf.reshape(img, [224, 224, 3])
    img = tf.image.per_image_standardization(img)
    labels = tf.cast(features['label'], tf.int32)
    return img, labels

#将图像打包 形成batch
def createBatch(filename, batchsize):
    images, labels = read_and_decode(filename)
    min_after_dequeue = 10
    capacity = min_after_dequeue + 3 * batchsize

    image_batch, label_batch = tf.train.shuffle_batch([images, labels],
                                                      batch_size=batchsize,
                                                      capacity=capacity,
                                                      min_after_dequeue=min_after_dequeue)
    label_batch = tf.one_hot(label_batch, depth=2)
    return image_batch, label_batch


if __name__ == "__main__":
    mapfile = '/xxxx/xxx/x.txt'
    train_filename = '/xxxxxx/xxx/xx/train.tfrecords'
    test_filename = '/xxxx/xxxx/xxx/test.tfrecords'

    image_batch, label_batch = createBatch(filename=train_filename, batchsize=2) #两幅图作为一个batch进行训练
    test_images, test_labels = createBatch(filename=test_filename, batchsize=20)
    with tf.Session as sess:
        initop = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
        sess.run(initop)
        #启动多线程处理数据
        coord = tf.train.Coordinator()
        thread = tf.train.start_queue_runners(sess=sess, coord=coord)

        try:
            step = 0
            while 1:
                _image_batch, _label_batch = sess.run([image_batch, label_batch])
                step += 1
                print(step)
                print(_label_batch)
        except tf.errors.OutOfRangeError:
            print("train data done!")

        try:
            step = 0
            while 1:
                _test_images, _test_labels = sess.run([test_images, test_labels])
                step += 1
                print(step)
                print(_test_labels)
        except tf.errors.OutOfRangeError:
            print("TESE done!")

        coord.request_stop()
        coord.join(thread)

 

写:

TFRecordWrite:

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
import os
from PIL import Image
#生成整数型的属性
def _int64_features(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

#生成字符串型的属性
def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

#制作TFRecord格式
def createTFRecord(filename, mapfile):
    class_map = {}
    #数据源地址
    data_dir = '/home/zsy/PycharmProjects/testTFR/'
    #数据文件夹
    classes = {'class1', 'class2'}

    #输出tf格式的文件地址,filename就是传入的地址
    writer = tf.python_io.TFRecordWriter(filename)

    for index, name in enumerate(classes):
        class_path = data_dir + name + '/'
        class_map[index] = name

        for img_name in os.listdir(class_path):
            img_path = class_path + img_name
            img = Image.open(img_path)
            img = img.resize((224, 224))
            img_row = img.tobytes()  #将图片转化为二进制格式
            #将一个样例转化为Example Protocol Buffer,并将所有信息写入这个数据结构
            example = tf.train.Example(features=tf.train.Feature(feature={
                'label':_int64_features(index),
                'image_raw':_bytes_feature(img_row)}))

            writer.write(example.SerializeToString())
    writer.close()


    textfile = open(mapfile, 'w+')
    for key in class_map.keys():
        textfile.writelines(str(key) + '.' + class_map[key] + '\n') #构造一个txt 存放所有图片名字
    textfile.close()


 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值