tf创建tfRecord文件

项目详细请猛戳我的github地址,直接可运行:https://github.com/SamXiaosheng/create-tfRecord

下面是main文件代码和create tfRecord文件:

import tensorflow as tf
from tfRecord import *
import cv2

FLAGS = tf.app.flags.FLAGS

tf.app.flags.DEFINE_string('image_dir', './image/',
                           """Directory where to write event logs """)

def main(_):
    create_tfrecords(FLAGS.image_dir)
    image_batch,label_batch =read_and_decode('test.tfRecord')
    with tf.Session() as sess:
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        while not coord.should_stop():
            image,label = sess.run([image_batch,label_batch])
            print(label)
            cv2.imshow('image',image[0])
            cv2.waitKey(200)
        coord.request_stop()
        coord.join(threads)

if __name__ == '__main__':
    tf.app.run()


import tensorflow as tf
import numpy as np
import os
import cv2

def read_and_decode(filename):
    filename_queue = tf.train.string_input_producer([filename])

    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(serialized_example,
                                       features={
                                           'label': tf.FixedLenFeature([], tf.int64),
                                           'img_raw' : tf.FixedLenFeature([], tf.string),
                                       })

    img = tf.decode_raw(features['img_raw'], tf.uint8)#这里的格式非常重要
    img = tf.reshape(img, [227, 227, 3])
    #img = tf.cast(img, tf.float32) * (1. / 255) - 0.5
    label = tf.cast(features['label'], tf.uint8)

    image_batch, label_batch = tf.train.shuffle_batch([img, label],
                                                    batch_size=1,#这里参数设置目的是每次只读取一个样本
                                                    capacity=1,
                                                    min_after_dequeue=0)
    #label_batch = tf.one_hot(label_batch, NUM_CLASSES)
    #label_batch = tf.cast(label_batch, dtype=tf.int64)
    #label_batch = tf.reshape(label_batch, [batch_size, NUM_CLASSES])

    return image_batch, label_batch
#读取某目录路径下的所有文件,返回图片的名称列表
def dirtomdfbatchmsra(dirpath):#读取目录下训练图像和对应的label
    image_ext = 'jpg'
    images = [fn for fn in os.listdir(dirpath) if fn.endswith(image_ext)]#返回dirpath路径下所有后缀jpg文件
    images.sort()#排序的目的有利于样本和标签的对应
    #print(images)
    gt_ext = 'png'
    gt_maps = [fn for fn in os.listdir(dirpath) if fn.endswith(gt_ext)]
    gt_maps.sort()
    #print(gt_maps)
    return gt_maps,images#返回gt图和训练image的所有文件名

def create_tfrecords(image_dir):
    writer = tf.python_io.TFRecordWriter("test.tfRecord")
    image_png,image_jpg = dirtomdfbatchmsra(image_dir)
    for index, name in enumerate(image_jpg):
            img = cv2.imread(image_dir+name).astype(np.uint8)
            img = cv2.resize(img,(227,227))#统一大小
            img_raw = img.tobytes()#转换成字节形式
            example = tf.train.Example(features=tf.train.Features(feature={
                "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
                'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
            }))
            writer.write(example.SerializeToString())
    for index, name in enumerate(image_png):
            img = cv2.imread(image_dir+name).astype(np.uint8)
            img = cv2.resize(img,(227, 227))#
            img_raw = img.tobytes()
            example = tf.train.Example(features=tf.train.Features(feature={
                "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
                'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
            }))
            writer.write(example.SerializeToString())
    writer.close()




  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值