import tensorflow as tf
import numpy as np
import os
import cv2
import matplotlib.pyplot as plt
import random
def get_example_nums(tf_records_filenames):
    """Count the serialized examples stored in a TFRecord file.

    :param tf_records_filenames: path handed to
        ``tf.python_io.tf_record_iterator``.
    :return: number of records the iterator yields.
    """
    return sum(1 for _ in tf.python_io.tf_record_iterator(tf_records_filenames))
def show_image(title, image):
    """Show ``image`` with matplotlib under the given ``title``.

    :param title: text used as the plot title.
    :param image: image data passed to ``plt.imshow``.
    :return: None.
    """
    plt.imshow(image)
    plt.title(title)
    # Keep the axes visible so pixel coordinates can be read off the plot.
    plt.axis('on')
    plt.show()
def load_labels_file(filename, labels_num=1, shuffle=False):
    """Parse a label file into parallel lists of image names and labels.

    Every non-blank line is expected to look like
    ``image_name label_1 [label_2 ...]`` with whitespace-separated fields.

    :param filename: path of the text file to read.
    :param labels_num: number of integer labels to read after the image name.
    :param shuffle: if True, shuffle the lines with ``random.shuffle`` before
        parsing.
    :return: ``(images, labels)`` where ``images`` is a list of image names
        and ``labels`` is a list of ``labels_num``-long lists of ints.
    :raises IndexError: if a line has fewer than ``labels_num`` label fields.
    :raises ValueError: if a label field is not an integer.
    """
    with open(filename) as f:
        # Blank lines (typically a trailing newline at the end of the file)
        # carry no sample and would otherwise fail when indexing the fields.
        lines_list = [raw for raw in f if raw.strip()]
    if shuffle:
        random.shuffle(lines_list)

    images = []
    labels = []
    for raw in lines_list:
        # split() without arguments tolerates tabs and repeated spaces,
        # which split(' ') would turn into empty fields.
        fields = raw.split()
        label = [int(fields[i + 1]) for i in range(labels_num)]
        images.append(fields[0])
        labels.append(label)
    return images, labels
def read_image(filename, resize_height, resize_width, normalization=False):
    """Load an image file as an RGB array, optionally resized and scaled.

    :param filename: path of the image file to read.
    :param resize_height: target height; the image is resized only when both
        ``resize_height`` and ``resize_width`` are positive.
    :param resize_width: target width (see ``resize_height``).
    :param normalization: if True, divide the pixel values by 255.0.
    :return: the RGB image array; the result of the division by 255.0 when
        ``normalization`` is True.
    :raises IOError: if OpenCV cannot read ``filename``.
    """
    bgr_image = cv2.imread(filename)
    if bgr_image is None:
        # cv2.imread reports a missing/unreadable file by returning None.
        raise IOError("cannot read image: {}".format(filename))
    if len(bgr_image.shape) == 2:
        # Single-channel result: expand to three channels first.
        print("warning:gray image", filename)
        bgr_image = cv2.cvtColor(bgr_image, cv2.COLOR_GRAY2BGR)
    # OpenCV loads images in BGR order; convert to RGB.
    rgb_image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)
    # Resize only when *both* target dimensions are valid.
    if resize_height > 0 and resize_width > 0:
        # cv2.resize expects the target size as (width, height).
        rgb_image = cv2.resize(rgb_image, (resize_width, resize_height))
    rgb_image = np.asanyarray(rgb_image)
    if normalization:
        rgb_image = rgb_image / 255.0
    return rgb_image
def get_batch_images(images, labels, batch_size, labels_nums, one_hot=False, shuffle=False, num_threads=1):
    """Group single image/label tensors into batches via TF input queues.

    :param images: image tensor for one example.
    :param labels: label tensor for one example.
    :param batch_size: number of examples per batch.
    :param labels_nums: depth passed to ``tf.one_hot`` when ``one_hot`` is True.
    :param one_hot: if True, convert the label batch with ``tf.one_hot``.
    :param shuffle: if True, batch with ``tf.train.shuffle_batch``; otherwise
        with ``tf.train.batch``.
    :param num_threads: number of enqueue threads for the batching op.
    :return: ``(images_batch, labels_batch)`` tensors.
    """
    min_after_dequeue = 200
    # Queue capacity: the shuffle reserve plus room for three batches.
    capacity = min_after_dequeue + 3 * batch_size
    if shuffle:
        images_batch, labels_batch = tf.train.shuffle_batch(
            [images, labels],
            batch_size=batch_size,
            capacity=capacity,
            min_after_dequeue=min_after_dequeue,
            num_threads=num_threads)
    else:
        # Same variable name as the shuffle branch so the return statement
        # works for both paths.
        images_batch, labels_batch = tf.train.batch(
            [images, labels],
            batch_size=batch_size,
            capacity=capacity,
            num_threads=num_threads)
    if one_hot:
        labels_batch = tf.one_hot(labels_batch, labels_nums, 1, 0)
    return images_batch, labels_batch
def read_record(filename, resize_height, resize_width, type=None):
    """Build tensors that read one image/label example from a TFRecord file.

    :param filename: path of the TFRecord file.
    :param resize_height: image height used to reshape the decoded bytes.
    :param resize_width: image width used to reshape the decoded bytes.
    :param type: post-processing mode for the image:
        ``None`` casts to float32 without scaling;
        ``'normalization'`` (the legacy misspelling ``'normalizetion'`` is
        still accepted) casts to float32 and multiplies by 1/255;
        ``'standardization'`` additionally subtracts 0.5.
        Any other value leaves the image as the reshaped uint8 tensor.
    :return: ``(tf_image, tf_label)``; the label is cast to int32.
    """
    filename_queue = tf.train.string_input_producer([filename])
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(
        serialized_example,
        features={
            'image_raw': tf.FixedLenFeature([], tf.string),
            'height': tf.FixedLenFeature([], tf.int64),
            'width': tf.FixedLenFeature([], tf.int64),
            'depth': tf.FixedLenFeature([], tf.int64),
            'label': tf.FixedLenFeature([], tf.int64),
        })
    tf_image = tf.decode_raw(features['image_raw'], tf.uint8)
    tf_label = tf.cast(features['label'], tf.int32)
    # The stored height/width/depth features are not used here: the flat byte
    # vector is reshaped with the caller-supplied size and 3 channels, so the
    # records must have been written with exactly these dimensions.
    tf_image = tf.reshape(tf_image, [resize_height, resize_width, 3])
    if type is None:
        tf_image = tf.cast(tf_image, tf.float32)
    elif type in ('normalization', 'normalizetion'):
        # Scale to [0, 1]. Callers pass 'normalization'; the old misspelled
        # key is kept so existing call sites keep working.
        tf_image = tf.cast(tf_image, tf.float32) * (1. / 255.0)
    elif type == 'standardization':
        # Scale to [0, 1] and shift to [-0.5, 0.5].
        tf_image = tf.cast(tf_image, tf.float32) * (1. / 255.0) - 0.5
    return tf_image, tf_label
def create_records(image_dir, file, output_record_dir, resize_height, resize_width, shuffle, log=5):
    """Write the images listed in a label file into one TFRecord file.

    Only the first label of every line is stored. Images that do not exist on
    disk are reported and skipped.

    :param image_dir: directory the image names in ``file`` are relative to.
    :param file: label file parsed by ``load_labels_file``.
    :param output_record_dir: path of the TFRecord file to create.
    :param resize_height: height passed to ``read_image``.
    :param resize_width: width passed to ``read_image``.
    :param shuffle: whether ``load_labels_file`` shuffles the entries.
    :param log: print progress every ``log`` entries (and for the last one).
    :return: None.
    """
    images_list, labels_list = load_labels_file(file, 1, shuffle)
    last_index = len(images_list) - 1

    writer = tf.python_io.TFRecordWriter(output_record_dir)
    try:
        for i, (image_name, labels) in enumerate(zip(images_list, labels_list)):
            image_path = os.path.join(image_dir, image_name)
            if not os.path.exists(image_path):
                print('Err:no image', image_path)
                continue
            image = read_image(image_path, resize_height, resize_width)
            # tobytes() replaces the deprecated ndarray.tostring() alias.
            image_raw = image.tobytes()
            if i % log == 0 or i == last_index:
                print('-------------processing:%d--------'%(i))
                print('current image_path=%s' % (image_path),'shape:{}'.format(image.shape),'labels:{}'.format(labels))
            label = labels[0]
            example = tf.train.Example(features=tf.train.Features(feature={
                'image_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_raw])),
                'height': tf.train.Feature(int64_list=tf.train.Int64List(value=[image.shape[0]])),
                'width': tf.train.Feature(int64_list=tf.train.Int64List(value=[image.shape[1]])),
                'depth': tf.train.Feature(int64_list=tf.train.Int64List(value=[image.shape[2]])),
                'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
            }))
            writer.write(example.SerializeToString())
    finally:
        # Close the writer even if reading or encoding an image fails.
        writer.close()
def disp_records(record_file, resize_heighe, resize_width, show_nums=4):
    """Read examples from a TFRecord file one by one and display them.

    :param record_file: path of the TFRecord file.
    :param resize_heighe: image height passed to ``read_record``.
    :param resize_width: image width passed to ``read_record``.
    :param show_nums: number of examples to fetch and show.
    :return: None.
    """
    tf_image, tf_label = read_record(record_file, resize_heighe, resize_width, type='normalization')
    # tf.initialize_all_variable does not exist; use the same initializer op
    # as batch_test.
    init_op = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init_op)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        for i in range(show_nums):
            # read_record already reshapes the image, so no reshape is
            # needed before displaying it.
            image, label = sess.run([tf_image, tf_label])
            print('shape:{},tpye:{},labels:{}'.format(image.shape,image.dtype,label))
            show_image("image:%d"%(label), image)
        # Ask the queue-runner threads to stop and wait for them.
        coord.request_stop()
        coord.join(threads)
def batch_test(record_file, resize_height, resize_width):
    '''
    Fetch four batches from a record file and show the first image of each.

    :param record_file: path of the record file
    :param resize_height: image height passed to read_record
    :param resize_width: image width passed to read_record
    :return: None
    :PS: image_batch and label_batch are what would normally be fed to a network
    '''
    # Single-example tensors read from the record file.
    single_image, single_label = read_record(
        record_file, resize_height, resize_width, type='normalization')
    image_batch, label_batch = get_batch_images(
        single_image,
        single_label,
        batch_size=4,
        labels_nums=5,
        one_hot=False,
        shuffle=False)
    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        for _ in range(4):
            # Pull one batch of images and labels out of the session.
            images, labels = sess.run([image_batch, label_batch])
            # Only the first image of every batch is displayed.
            first_image = images[0, :, :, :]
            show_image("image", first_image)
            print('shape:{},tpye:{},labels:{}'.format(images.shape,images.dtype,labels))
        # Stop all queue-runner threads.
        coord.request_stop()
        coord.join(threads)
if __name__ == '__main__':
    # Settings
    resize_height = 199  # height of the stored images
    resize_width = 199  # width of the stored images
    shuffle = True
    log = 5

    # Generate the train record file
    image_dir = 'E:/learning/musemart/dataset_updated/training_set'
    # Label file handed to create_records
    train_labels = 'E:/learning/musemart/dataset_updated/training_set/art.txt'
    train_record_output = 'E:/learning/musemart/dataset_updated/training_set/train.tfrecords'
    create_records(
        image_dir=image_dir,
        file=train_labels,
        output_record_dir=train_record_output,
        resize_height=resize_height,
        resize_width=resize_width,
        shuffle=shuffle,
        log=log)
    train_nums = get_example_nums(train_record_output)
    print("save train example nums={}".format(train_nums))

    # Generate the val record file (disabled)
    # image_dir = 'dataset/val'
    # val_labels = 'dataset/val.txt'
    # val_record_output = 'dataset/record/val.tfrecords'
    # create_records(image_dir, val_labels, val_record_output, resize_height, resize_width, shuffle, log)
    # val_nums = get_example_nums(val_record_output)
    # print("save val example nums={}".format(val_nums))

    # Try out the display helpers
    # disp_records(train_record_output, resize_height, resize_width)
    batch_test(train_record_output, resize_height, resize_width)
# TFRecords的生成和读取
# 最新推荐文章于 2024-01-29 08:58:49 发布