#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon May 28 15:29:51 2018
"""
import tensorflow as tf
import numpy as np
from PIL import Image
import os
from scipy import misc
import time
# Collect the ORL image paths (under ./ORL) and their integer labels.
temp_dir = os.getcwd()
# os.listdir returns entries in arbitrary order; sort so the record file is
# written in the same order on every run.
files = sorted(os.listdir('ORL'))
files_name = [os.path.join(temp_dir, 'ORL', x) for x in files]
# The label is the text between the last '_' and the first '.' of the file
# name (e.g. "xxx_7.pgm" -> 7). Parse the bare file name, not the full path,
# so characters in temp_dir can never leak into the label.
labels = [int(x.split('_')[-1].split('.')[0]) for x in files]
def create_record(files_name, labels, name):
    """Write images and their labels to ``name + '.tfrecord'``.

    Every image is stored as one ``tf.train.Example`` holding two features:
    ``'imgs'`` (the raw pixel bytes, as a ``bytes_list``) and ``'label'``
    (the integer label, as an ``int64_list``).

    Args:
        files_name: sequence of image file paths.
        labels: sequence of integer labels, one per entry of ``files_name``.
        name: output path without extension; ``'.tfrecord'`` is appended.

    Raises:
        ValueError: if ``files_name`` and ``labels`` differ in length.

    Note:
        Only the raw bytes are written -- not the shape or dtype -- so a
        reader must decode them with the dtype of the stored array and
        reshape them to the exact image shape.
    """
    if len(files_name) != len(labels):
        raise ValueError(
            'files_name and labels must have the same length '
            '(got %d and %d)' % (len(files_name), len(labels)))
    # The context manager closes the writer even if reading an image fails
    # part-way through, instead of leaking the open file.
    with tf.python_io.TFRecordWriter(name + '.tfrecord') as writer:
        for path, label in zip(files_name, labels):
            # scipy.misc.imread was removed in SciPy 1.2; read through PIL
            # (the library imread wrapped) and convert to a numpy array.
            with Image.open(path) as im:
                img = np.array(im)
            img_raw = img.tobytes()  # raw pixel bytes, no shape information
            example = tf.train.Example(features=tf.train.Features(feature={
                'imgs': tf.train.Feature(
                    bytes_list=tf.train.BytesList(value=[img_raw])),
                'label': tf.train.Feature(
                    int64_list=tf.train.Int64List(value=[int(label)])),
            }))
            writer.write(example.SerializeToString())
def read_record(filename, h, w):
    """Build graph ops that read one (image, label) pair from a TFRecord file.

    The reader is queue-based, so the returned tensors only yield values
    after queue runners have been started in a session.

    Args:
        filename: path of the ``.tfrecord`` file to read.
        h: image height used when reshaping the decoded bytes.
        w: image width used when reshaping the decoded bytes.

    Returns:
        A ``(img, label)`` pair: a uint8 tensor of shape ``[h, w]`` and an
        int32 scalar tensor.
    """
    # A filename queue holding just this one file, in fixed order.
    name_queue = tf.train.string_input_producer([filename], shuffle=False)
    record_reader = tf.TFRecordReader()
    _, record = record_reader.read(name_queue)

    spec = {
        'imgs': tf.FixedLenFeature([], tf.string),
        'label': tf.FixedLenFeature([], tf.int64),
    }
    parsed = tf.parse_single_example(record, features=spec)

    # The stored bytes carry no shape, so [h, w] must describe the written
    # image exactly; a mismatch makes the reshape fail when the op runs.
    pixels = tf.reshape(tf.decode_raw(parsed['imgs'], tf.uint8), [h, w])
    target = tf.cast(parsed['label'], tf.int32)
    return pixels, target
# Write the dataset, then build the input pipeline that reads it back.
create_record(files_name, labels, 'train')
img, label = read_record('train.tfrecord', 112, 92)
# shuffle_batch draws batches of 20 from a shuffled buffer holding at most
# `capacity` elements and at least `min_after_dequeue` after each dequeue.
img_batch, label_batch = tf.train.shuffle_batch(
    [img, label], batch_size=20, capacity=200,
    min_after_dequeue=100, num_threads=6)
init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    coord = tf.train.Coordinator()  # coordinates shutdown of the runner threads
    # Start the QueueRunners; from here on the filename queue is being filled.
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        for i in range(200):
            # Separate names so the graph tensors `img`/`label` are not
            # overwritten by the fetched numpy values.
            img_val, label_val = sess.run([img_batch, label_batch])
            print(img_val)
            print(label_val)
            print('----------------------------', i, '--------------------------')
            time.sleep(1)
    finally:
        # Always stop and join the runner threads, even when sess.run raises
        # (e.g. OutOfRangeError after the queue was closed). join() re-raises
        # an error that a runner thread reported to the coordinator, which
        # presumably surfaces the root cause (such as a failed reshape).
        coord.request_stop()
        coord.join(threads)
运行时报如下错误：
OutOfRangeError (see above for traceback): RandomShuffleQueue '_1_shuffle_batch/random_shuffle_queue' is closed and has insufficient elements (requested 10, current size 0)
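经检查发现，读入的图片形状是 (112, 92)（单通道），而我却把它 reshape 成了 (112, 92, 3)。元素个数对不上，读取线程中的 reshape 出错，队列随之被关闭且其中没有任何元素，于是 shuffle_batch 抛出了上面的 OutOfRangeError。
将 reshape 的目标形状改为 (112, 92) 后，运行正常。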