TFRecords的生成和读取

import tensorflow as tf
import numpy as np
import os
import cv2
import matplotlib.pyplot as plt
import random


def get_example_nums(tf_records_filenames):
    """Count the serialized examples stored in a TFRecord file.

    :param tf_records_filenames: path handed to
        ``tf.python_io.tf_record_iterator``.
    :return: number of records the iterator yields (0 when it yields none).
    """
    record_iter = tf.python_io.tf_record_iterator(tf_records_filenames)
    return sum(1 for _ in record_iter)

def show_image(title, image):
    """Draw ``image`` with matplotlib and show the figure.

    :param title: text passed to ``plt.title``.
    :param image: image data passed to ``plt.imshow``.
    :return: None
    """
    plt.imshow(image)
    plt.title(title)
    plt.axis('on')  # leave the axis lines and labels switched on
    plt.show()

def load_labels_file(filename,labels_num=1,shuffle=False):
    """Parse a label list file into image names and integer labels.

    Every non-blank line is expected to look like
    ``<image_name> <label_1> [<label_2> ...]`` with the fields separated by
    any run of whitespace (spaces or tabs). Blank and whitespace-only lines
    are ignored, so a trailing empty line at the end of the file is harmless.

    :param filename: path of the text file to read.
    :param labels_num: number of label columns to read after the image name.
    :param shuffle: when true, the lines are shuffled with ``random.shuffle``
        before they are parsed.
    :return: ``(images, labels)`` where ``images`` is a list of image names
        and ``labels`` is a list of ``labels_num``-long lists of ints, in
        matching order.
    :raises IndexError: a non-blank line has fewer than ``labels_num`` labels.
    :raises ValueError: a label field is not an integer.
    """
    with open(filename) as f:
        # Drop blank lines up front so they neither crash the parser nor
        # take part in the shuffle.
        lines_list = [raw_line for raw_line in f if raw_line.strip()]

    if shuffle:
        random.shuffle(lines_list)

    images = []
    labels = []
    for raw_line in lines_list:
        # split() with no argument collapses repeated spaces and tabs;
        # split(' ') would produce empty fields for them.
        fields = raw_line.split()
        label = [int(fields[i + 1]) for i in range(labels_num)]
        images.append(fields[0])
        labels.append(label)
    return images,labels

def read_image(filename,resize_height,resize_width,normalization=False):
    """Load an image file as an RGB array, optionally resized and scaled.

    :param filename: path of the image file, read with ``cv2.imread``.
    :param resize_height: target height; resizing happens only when both
        ``resize_height`` and ``resize_width`` are greater than 0.
    :param resize_width: target width (see ``resize_height``).
    :param normalization: when true, pixel values are divided by 255.0.
    :return: the RGB image array; divided by 255.0 when ``normalization``
        is true, otherwise as produced by OpenCV.
    :raises IOError: ``cv2.imread`` could not load ``filename``.
    """
    bgr_image = cv2.imread(filename)
    if bgr_image is None:
        # imread signals failure by returning None instead of raising;
        # fail here with the offending path rather than on `.shape` below.
        raise IOError("cannot read image: {}".format(filename))
    if len(bgr_image.shape)==2:
        print("warning:gray image",filename)
        bgr_image = cv2.cvtColor(bgr_image,cv2.COLOR_GRAY2BGR)
    rgb_image = cv2.cvtColor(bgr_image,cv2.COLOR_BGR2RGB)
    # Both dimensions must be positive (the original tested the height twice).
    if resize_height>0 and resize_width>0:
        # cv2.resize takes the target size as (width, height).
        rgb_image = cv2.resize(rgb_image,(resize_width,resize_height))
    rgb_image = np.asanyarray(rgb_image)
    if normalization:
        rgb_image = rgb_image/255.0
    return rgb_image

def get_batch_images(images,labels,batch_size,labels_nums,one_hot=False,shuffle=False,num_threads=1):
    """Group single image/label tensors into batches.

    :param images: image tensor for one example.
    :param labels: label tensor for one example.
    :param batch_size: number of examples per batch.
    :param labels_nums: depth passed to ``tf.one_hot`` when ``one_hot`` is true.
    :param one_hot: when true, the label batch is converted with
        ``tf.one_hot(labels_batch, labels_nums, 1, 0)``.
    :param shuffle: when true use ``tf.train.shuffle_batch``, otherwise
        ``tf.train.batch``.
    :param num_threads: number of enqueue threads handed to the batch op.
    :return: ``(images_batch, labels_batch)`` tensors.
    """
    min_after_dequeue = 200
    capacity = min_after_dequeue + 3 * batch_size
    # Both branches bind the same name; the original bound `images_batch`
    # here but returned `image_batch`, so shuffle=True raised NameError.
    if shuffle:
        images_batch, labels_batch = tf.train.shuffle_batch([images, labels],
                                                            batch_size=batch_size,
                                                            capacity=capacity,
                                                            min_after_dequeue=min_after_dequeue,
                                                            num_threads=num_threads)
    else:
        images_batch, labels_batch = tf.train.batch([images, labels],
                                                    batch_size=batch_size,
                                                    capacity=capacity,
                                                    num_threads=num_threads)
    if one_hot:
        labels_batch = tf.one_hot(labels_batch, labels_nums, 1, 0)
    return images_batch, labels_batch

def read_record(filename,resize_height,resize_width,type=None):
    """Build the graph ops that read and decode one example from a TFRecord file.

    :param filename: path of the TFRecord file; it is wrapped in a
        one-element list for ``tf.train.string_input_producer``.
    :param resize_height: height used to reshape the decoded image bytes.
    :param resize_width: width used to reshape the decoded image bytes.
    :param type: pixel scaling mode:

        * ``None`` - cast to float32 without scaling;
        * ``'normalization'`` - cast to float32 and multiply by 1/255
          (the legacy misspelling ``'normalizetion'`` is still accepted);
        * ``'standardization'`` - multiply by 1/255, then subtract 0.5.

        Any other value leaves the image as the decoded uint8 tensor.
    :return: ``(tf_image, tf_label)`` - the image tensor reshaped to
        ``[resize_height, resize_width, 3]`` and the label cast to int32.
    """
    filename_queue = tf.train.string_input_producer([filename])
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(serialized_example,
                                       features={
                                           'image_raw': tf.FixedLenFeature([], tf.string),
                                           'height': tf.FixedLenFeature([], tf.int64),
                                           'width': tf.FixedLenFeature([], tf.int64),
                                           'depth': tf.FixedLenFeature([], tf.int64),
                                           'label': tf.FixedLenFeature([], tf.int64)}
                                       )
    tf_image = tf.decode_raw(features['image_raw'], tf.uint8)
    tf_label = tf.cast(features['label'], tf.int32)

    # The stored height/width/depth features are parsed but not used; the
    # caller-supplied size with 3 channels determines the shape.
    tf_image = tf.reshape(tf_image, [resize_height, resize_width, 3])
    if type is None:
        tf_image = tf.cast(tf_image, tf.float32)
    elif type in ('normalization', 'normalizetion'):
        # Callers in this file pass 'normalization'; the original compared
        # only against the misspelt form, so this branch never ran for them.
        tf_image = tf.cast(tf_image, tf.float32) * (1. / 255.0)
    elif type == 'standardization':
        tf_image = tf.cast(tf_image, tf.float32) * (1. / 255.0) - 0.5
    return tf_image, tf_label

def create_records(image_dir,file,output_record_dir,resize_height,resize_width,shuffle,log=5):
    """Write the images listed in a label file into a TFRecord file.

    Each example stores the raw image bytes (``image_raw``), the image
    ``height``/``width``/``depth`` and the first label of the line (``label``).
    Listed images that do not exist under ``image_dir`` are reported and skipped.

    :param image_dir: directory the image names in ``file`` are relative to.
    :param file: label list file, parsed with ``load_labels_file(file, 1, shuffle)``.
    :param output_record_dir: path passed to ``tf.python_io.TFRecordWriter``
        as the output file (despite the name it is used as a file path).
    :param resize_height: height forwarded to ``read_image``.
    :param resize_width: width forwarded to ``read_image``.
    :param shuffle: forwarded to ``load_labels_file``.
    :param log: print progress every ``log`` images and for the last image;
        a value of 0 prints only for the last image.
    :return: None
    """
    images_list, labels_list = load_labels_file(file, 1, shuffle)
    last_index = len(images_list) - 1

    writer = tf.python_io.TFRecordWriter(output_record_dir)
    try:
        for i, (image_name, labels) in enumerate(zip(images_list, labels_list)):
            image_path = os.path.join(image_dir, image_name)
            if not os.path.exists(image_path):
                print('Err:no image',image_path)
                continue
            image = read_image(image_path, resize_height, resize_width)
            # tobytes() is the supported spelling of the deprecated tostring().
            image_raw = image.tobytes()
            # `log and ...` avoids a ZeroDivisionError when log == 0.
            if (log and i % log == 0) or i == last_index:
                print('-------------processing:%d--------'%(i))
                print('current image_path=%s' % (image_path),'shape:{}'.format(image.shape),'labels:{}'.format(labels))
            label = labels[0]
            example = tf.train.Example(features=tf.train.Features(feature={
                'image_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_raw])),
                'height': tf.train.Feature(int64_list=tf.train.Int64List(value=[image.shape[0]])),
                'width': tf.train.Feature(int64_list=tf.train.Int64List(value=[image.shape[1]])),
                'depth': tf.train.Feature(int64_list=tf.train.Int64List(value=[image.shape[2]])),
                'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))
                }))
            writer.write(example.SerializeToString())
    finally:
        # Close the writer even if reading or encoding an image fails.
        writer.close()
def disp_records(record_file,resize_heighe,resize_width,show_nums=4):
    """Read examples one at a time from a record file and display them.

    :param record_file: path of the TFRecord file, forwarded to ``read_record``.
    :param resize_heighe: image height forwarded to ``read_record``
        (the parameter name keeps its original spelling for compatibility).
    :param resize_width: image width forwarded to ``read_record``.
    :param show_nums: number of examples to fetch and show.
    :return: None
    """
    tf_image, tf_label = read_record(record_file, resize_heighe, resize_width, type='normalization')
    # The original called tf.initialize_all_variable(), which does not exist.
    init_op = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init_op)
        # The original misspelt this as tf.train.Coordonator.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            for i in range(show_nums):
                image, label = sess.run([tf_image, tf_label])
                print('shape:{},tpye:{},labels:{}'.format(image.shape,image.dtype,label))
                show_image("image:%d"%(label),image)
        finally:
            # Stop and join the queue-runner threads even if a fetch fails.
            coord.request_stop()
            coord.join(threads)
            
def batch_test(record_file,resize_height, resize_width):
    '''
    Read a record file in batches of 4 and show the first image of each batch.

    :param record_file: path of the record file
    :param resize_height: image height forwarded to read_record
    :param resize_width: image width forwarded to read_record
    :return: None
    :PS: the image/label batch tensors are what would normally feed a network
    '''
    # Single-example tensors read from the record file
    single_image, single_label = read_record(
        record_file, resize_height, resize_width, type='normalization')
    batched_images, batched_labels = get_batch_images(
        single_image, single_label,
        batch_size=4, labels_nums=5, one_hot=False, shuffle=False)

    init_op = tf.global_variables_initializer()
    with tf.Session() as session:  # open a session
        session.run(init_op)
        coordinator = tf.train.Coordinator()
        runner_threads = tf.train.start_queue_runners(coord=coordinator)
        fetches = [batched_images, batched_labels]
        for _ in range(4):
            # Pull one batch of images and labels out of the session
            images, labels = session.run(fetches)
            # Only the first image of each batch is displayed
            first_image = images[0, :, :, :]
            show_image("image", first_image)
            print('shape:{},tpye:{},labels:{}'.format(images.shape,images.dtype,labels))

        # Stop all queue-runner threads
        coordinator.request_stop()
        coordinator.join(runner_threads)

if __name__ == '__main__':
    # Settings
    shuffle = True
    log = 5
    resize_height = 199  # height of the stored images
    resize_width = 199  # width of the stored images

    # Build the train record file
    image_dir = 'E:/learning/musemart/dataset_updated/training_set'
    train_labels = 'E:/learning/musemart/dataset_updated/training_set/art.txt'  # label list for the images
    train_record_output = 'E:/learning/musemart/dataset_updated/training_set/train.tfrecords'
    create_records(
        image_dir, train_labels, train_record_output,
        resize_height, resize_width, shuffle, log,
    )
    train_nums = get_example_nums(train_record_output)
    print("save train example nums={}".format(train_nums))

    # Build the val record file (disabled)
    #image_dir='dataset/val'
    #val_labels = 'dataset/val.txt'  # label list for the images
    #val_record_output = 'dataset/record/val.tfrecords'
    #create_records(image_dir,val_labels, val_record_output, resize_height, resize_width,shuffle,log)
    #val_nums=get_example_nums(val_record_output)
    #print("save val example nums={}".format(val_nums))

    # Try out the display helpers
    # disp_records(train_record_output,resize_height, resize_width)
    batch_test(train_record_output, resize_height, resize_width)
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值