tfrecorder的制作/读取(附实战代码)

import tensorflow as tf  
import numpy  

def write_binary():  
    writer = tf.python_io.TFRecordWriter('data.tfrecord')  
    #创建example  
    for i in range(0, 100):  
        a = 0.618 + i  
        b = [2016 + i, 2017+i]  
        c = numpy.array([[0, 1, 2],[3, 4, 5]]) + i  
        c = c.astype(numpy.uint8)  
        c_raw = c.tostring() #转化成字符串  
    #每个example的feature成员变量是一个dict,存储一个样本的不同部分(例如图像像素+类标)  
        example = tf.train.Example(  
            features=tf.train.Features(  
                feature={  
                    'a': tf.train.Feature(  
                        float_list=tf.train.FloatList(value=[a])  
                    ),  

                    'b': tf.train.Feature(  
                        int64_list=tf.train.Int64List(value=b)  
                    ),  
                    'c': tf.train.Feature(  
                        bytes_list=tf.train.BytesList(value=[c_raw])  
                    )  
                }  
            )  
        )  
    #序列化  
        serialized = example.SerializeToString()  
    #写入文件  
        writer.write(serialized)  
    writer.close()  

def read_single_sample(filename):  
    #创建文件队列,不限读取的数量  
    filename_queue = tf.train.string_input_producer([filename], num_epochs=None)  
    # create a reader from file queue  
    reader = tf.TFRecordReader()  
    #reader从文件队列中读入一个序列化的样本  
    _, serialized_example = reader.read(filename_queue)  

    # get feature from serialized example  
    #解析符号化的样本  
    features = tf.parse_single_example(  
        serialized_example,  
        features={  
            'a': tf.FixedLenFeature([], tf.float32),  
            'b': tf.FixedLenFeature([2], tf.int64),  
            'c': tf.FixedLenFeature([], tf.string)  
        }  
    )  
    a = features['a']  
    b = features['b']  
    c_raw = features['c']  
    c = tf.decode_raw(c_raw, tf.uint8)  
    c = tf.reshape(c, [2, 3])  
    return a, b, c  

# 
#write_binary()  
#else:  
# create tensor  
a, b, c = read_single_sample('data.tfrecord')  

a_batch, b_batch, c_batch = tf.train.shuffle_batch([a, b, c], batch_size=5, capacity=200, min_after_dequeue=100, num_threads=2)  

# sess  
sess = tf.Session()  
init = tf.initialize_all_variables()  
sess.run(init)  

tf.train.start_queue_runners(sess=sess)  

for step in range(3):  
    a_val, b_val, c_val = sess.run([a_batch, b_batch, c_batch])  
    print(a_val, b_val, c_val)  

实战代码:

import tensorflow as tf  
import numpy  
import scipy.misc as misc
import os
import cv2
def write_binary():  
    cwd = os.getcwd()
    classes=['ym','zly','lyf']
    writer = tf.python_io.TFRecordWriter('data.tfrecord')  
    for index, name in enumerate(classes):
        class_path = os.path.join(cwd,name)
        for img_name in os.listdir(class_path):
            img_path = os.path.join(class_path , img_name)
            img = misc.imread(img_path)
            img1 = misc.imresize(img,[250,250,3])
            img_raw = img1.tobytes()              #将图片转化为原生bytes
            example = tf.train.Example(features=tf.train.Features(feature={
                    'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])),
                "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index]))}
                ))

    #序列化  
            serialized = example.SerializeToString()  
    #写入文件  
            writer.write(serialized)  
    writer.close()  



def read_and_decode(filename):  
    #创建文件队列,不限读取的数量  
    filename_queue = tf.train.string_input_producer([filename],shuffle=False)  
    # create a reader from file queue  
    reader = tf.TFRecordReader()  
    #reader从文件队列中读入一个序列化的样本  
    _, serialized_example = reader.read(filename_queue)  

    features = tf.parse_single_example(  
        serialized_example,  
        features={  
            'label': tf.FixedLenFeature([], tf.int64),  
            'img_raw': tf.FixedLenFeature([], tf.string) 

        }  
    )  
    img=tf.decode_raw(features['img_raw'],tf.uint8)
    img = tf.reshape(img, [250, 250, 3])
    label = tf.cast(features['label'], tf.int32)
    return img,label  


#write_binary()  


img,label = read_and_decode('data.tfrecord')  


img_batch, label_batch = tf.train.shuffle_batch([img,label], batch_size=18, capacity=200, min_after_dequeue=100, num_threads=2)  

# sess  
init = tf.global_variables_initializer() 
sess = tf.Session()  

sess.run(init)  
coord = tf.train.Coordinator()  
threads=tf.train.start_queue_runners(sess=sess,coord=coord)  

img, label = sess.run([img_batch, label_batch])
for i in range(18):   
    [b,g,r]=[cv2.split(img[i])[0],cv2.split(img[i])[1],cv2.split(img[i])[2]]
    cv2.imwrite('%d.png'%i,cv2.merge([r,g,b]))
coord.request_stop()
coord.join(threads)
sess.close()
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值