自己数据制作tfrecords格式的数据集

数据集格式,将所有测试或者训练数据集各自保存在一个文件夹下:如下图:训练数据集

制作.tfrecords格式数据集的代码:

import os
import tensorflow as tf
from PIL import Image  #注意Image,后面会用到

# Imagenet图片都保存在/data目录下,里面有1000个子目录,获取这些子目录的名字
classes = os.listdir('G:\\train\\')

cwd1='G:\\test\\'
cwd2='G:\\train\\'


print(classes)
writer1 = tf.python_io.TFRecordWriter("test.tfrecords")  # 要生成的文件
writer2 = tf.python_io.TFRecordWriter("train.tfrecords")  # 要生成的文件


for index in range(10):
    print(index)
    name=classes[index]
    print(name)
    class_path1 = cwd1 + name + '/'     #训练数据集的路径
    class_path2 = cwd2 + name + '/'     #测试数据集的路径
    for img_name in os.listdir(class_path1):
        img_path = class_path1 + img_name  # 每一个图片的地址
        img = Image.open(img_path)
        img = img.resize((32, 32))
        img_raw = img.tobytes()  # 将图片转化为二进制格式
        example = tf.train.Example(features=tf.train.Features(feature={
            "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
            'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
        }))  # example对象对label和image数据进行封装
        writer1.write(example.SerializeToString())  # 序列化为字符串
    for img_name2 in os.listdir(class_path2):
        img_path2 = class_path2 + img_name2  # 每一个图片的地址

        img2 = Image.open(img_path2)
        img2 = img2.resize((64, 64))
        img_raw2 = img2.tobytes()  # 将图片转化为二进制格式
        example2 = tf.train.Example(features=tf.train.Features(feature={
            "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
            'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw2]))
        }))  # example对象对label和image数据进行封装
        writer2.write(example2.SerializeToString())  # 序列化为字符串
writer1.close()
writer2.close()

 

读取数据集代码:

def read_and_decode(filename):  # 读入dog_train.tfrecords
    filename_queue = tf.train.string_input_producer([filename])  # 生成一个queue队列

    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)  # 返回文件名和文件
    features = tf.parse_single_example(serialized_example,
                                       features={
                                           'label': tf.FixedLenFeature([], tf.int64),
                                           'img_raw': tf.FixedLenFeature([], tf.string),
                                       })  # 将image数据和label取出来
    img = tf.decode_raw(features['img_raw'], tf.uint8)
    img = tf.reshape(img, [32 * 32 * 3])  # reshape为128*128的3通道图片
    img = tf.cast(img, tf.float32) * (1. / 255)
    label = tf.cast(features['label'], tf.int32)  # 在流中抛出label张量
    return img, label

#打乱数据
def createBatch(filename, batchsize):
    images, labels = read_and_decode(filename)

    min_after_dequeue = batchsize
    capacity = 3 * batchsize
    image_batch, label_batch = tf.train.shuffle_batch([images, labels],
                                                      batch_size=batchsize,
                                                      capacity=capacity,
                                                      min_after_dequeue=min_after_dequeue
                                                      )
    label_batch = tf.one_hot(label_batch, depth=10)
    return image_batch, label_batch

使用细节多个tfrecord文件


data_dir = '/home/sanyuan/dataset_animal/dataset_tfrecords/' 
 
filenames = [os.path.join(data_dir,'train%d.tfrecords' % ii) for ii in range(1)] #如果有多个文件,直接更改这里即可
filename_queue = tf.train.string_input_producer(filenames)
image = read_and_decode(filename_queue)
tfrecord本质是创建一个文件队列,创建一个内存队列,内存队列不需要创建,文件队列则需要通过tfrecords文件名来创建。

shuffle参数判断远问是否需要打乱,num_epochs迭代的代数100

filename_queue = tf.train.string_input_producer([filename],shuffle=True,num_epochs=100)  # 生成一个queue队列
  • 8
    点赞
  • 17
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值