制作自己的数据集tfrecord格式

本文介绍了如何使用TensorFlow创建多分类数据集的tfrecord格式,包括解决将多类标签转换为整型的问题。作者分享了遇到的挑战及解决方案,适合机器学习新手参考。
摘要由CSDN通过智能技术生成

最近接触TensorFlow,需要训练自己的数据集,看到很多博客资料,了解到TensorFlow中自带的tfrecord文件,但是自己具体实现起来发现自己的情况与资料的一些不太一样,所以把自己遇到的问题归纳整理出来。新手一枚,水平有限,有许多问题的解决可能仅限于解决,代码并有优化,有些思路可能走了弯路,希望能跟大家交流。

1.问题1:对于多分类情况,怎么确定标签?

(1)多分类:大多资料中给出的是针对两种分类的情况,采用的是直接用class={class1 , class2}这种格式,但是对于很多类的话,依次写出类别有点麻烦,那么可以采用先定义一个列表classes = []来存储目录中所有的分类,比如对于字符识别,那么classes = {1,2,3,4...,A,B,C},然后用for index, name in enumerate(classes)将对应文件夹的名字与整数一一对应起来。

其中enumerate是python中的一个函数,目的会将index和classes中的name对应起来,比如classes = {1,2,3,A,...}那么index = {1, 2, 3, 4,..}并且与classes中的1,2,3 ,A这些对应。

为什么要这样做?因为在存入tfrecord的时候,标签一般用的是整型,当目录文件中包含A,B,或者字符串的时候要将其变为整型,我尝试过读入tfrecord的时候用tobyte格式,也就是直接用字符串的形式读入,但是会报错,也可能是我知识水平不够,没有找到正确的方法。

classes = []
    for class1 in os.listdir(cwd):
        classes.append(class1)
    for index, name in enumerate(classes):
        class_path = cwd + name + '\\'
        for img_name in os.listdir(class_path):
            img_path = class_path + img_name  # 每一个图片的地址

2.问题二:如何在读入的时候分数据集和测试集(其中测试集占50%)
(1)我采用的是在逐层访问文件夹的时候用两个字典(一对多)存入图片的标签和对应图片地址。
    for class1 in os.listdir(cwd):
        classes.append(class1)
    for index, name in enumerate(classes):
        class_path = cwd + name + '\\'
        for img_name in os.listdir(class_path):
            img_path = class_path + img_name  # 每一个图片的地址
            m += 1
            if(m % 5) == 0:
                len_testing_dataset += 1
                testing_dataset[index].append(img_path)
            else:
                len_training_dataset += 1
                training_dataset[index].append(img_path)

3.问题三:数据格式的变换
(1)从tfrecord中读出的数据格式是tensor格式,我之前跟着教程构建的带有计算图的CNN,它输入的数据格式和mnist数据集是一
样的,那么要将输出的tensor格式转化为与mnist数据集一样的格式,并且标签采用one-hot编码格式
def to_one_hot(classes, label):
    num_classes = len(classes)
    # print(num_classes)
    # print("label-----------",label)
    label_arr = np.zeros((num_classes))
    # print("label_arr---------",label_arr)
    label_arr[label] += 1.0
    # print("after change label_arr",label_arr)
    return label_arr

def importimg(imagepath,m,classes):
    #imagepath为读入的图片tfrecord的地址
    #imagepath = "data_train.tfrecords"
    # min_after_dequeue = 15
    # batch_size = 1
    # capacity = min_after_dequeue + 3 * batch_size
    # print(imagepath)
    # print("m------------",m)
    print("开始读入数据----------------------------------")
    filename_queue = tf.train.string_input_producer([imagepath]) #读入流中
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)   #返回文件名和文件
    features = tf.parse_single_example(serialized_example,
                                       features={
                                           'label': tf.FixedLenFeature([], tf.int64),
                                           'img_raw': tf.FixedLenFeature([], tf.string),
                                       })  # 取出包含image和label的feature对象
    image = tf.decode_raw(features['img_raw'], tf.uint8)
    # print("从tfrecord文件中读取数据image", image)
    image = tf.reshape(image, [-1])
    # print("after reshape of image-----------------",image)
    label = tf.cast(features['label'], tf.int32)  # 在流中抛出label张量
    with tf.Session() as sess:  # 开始一个会话
        init_op = tf.global_variables_initializer()
        sess.run(init_op)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        labels = []
        images = []
        for i in range(m):
            print("第", i, "个",imagepath,"数据正在读取中")
            image1, label1 = sess.run([image, label])  # 在会话中取出image和label
            image = tf.cast(image1, tf.float32)
            label_arr = to_one_hot(classes, label1)
            labels.append(label_arr)
            images.append(image1)
            labels_arr = np.array(labels)
            images_arr = np.array(images)
        # print("labels_arr------------",labels_arr)
        # print("images_arr------------",images_arr)
        coord.request_stop()
        coord.join(threads)
    return images_arr, labels_arr

总的代码:
import os
import tensorflow as tf
from PIL import Image
from collections import defaultdict
from itertools import groupby
#import matplotlib.pyplot as plt
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import numpy as np

#读图片地址,CNN,已经测试正确
def read_image(cwd):
    #m记录样本数
    m = 0
    classes = []
    len_testing_dataset = 0
    len_training_dataset = 0
    training_dataset = defaultdict(list)
    testing_dataset = defaultdict(list)
    for class1 in os.listdir(cwd):
        classes.append(class1)
    for index, name in enumerate(classes):
        class_path = cwd + name + '\\'
        for img_name in os.listdir(class_path):
            img_path = class_path + img_name  # 每一个图片的地址
            m += 1
            if(m % 5) == 0:
                len_testing_dataset += 1
                testing_dataset[index].append(img_path)
            else:
                len_training_dataset += 1
                training_dataset[index].append(img_path)
    print("training_dataset testing_dataset END ------------------------------------------------------")
    return m, classes, training_dataset, testing_dataset, len_testing_dataset, len_training_dataset

# m, classes, training_dataset, testing_dataset, len_testing_dataset, len_training_dataset = read_image(
#      'E:\datafortest\Testlib1\\'
# )


#CNN,写数据,已经测试正确
def write_data(dataset, newfilepath):
    writer = tf.python_io.TFRecordWriter(newfilepath)  # 要生成的文件
    for label, img in dataset.items():
        for img_path in img:
            print("img_path------------",img_path)
            img = Image.open(img_path)
            img = img.resize((15, 15))
            img_raw = img.tobytes()  # 将图片转化为二进制格式,uint8
            #img_decode = img_raw.decode('utf-8')
            #print(img_decode)
            example = tf.train.Example(features=tf.train.Features(feature={
                "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
                'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
            }))  # example对象对label和image数据进行封装
            writer.write(example.SerializeToString())  # 序列化为字符串
    print("成功存入tfrecord文件")
    writer.close()
    # 返回总的类别数,和所有的类别标号

#CNN,写数据,已经测试正确
def write_mul_data(dataset, newfilepath, record_location):
    writer = None
    current_index = 0
    for label, img in dataset.items():
        for img_path in img:
            print("img_path------------", img_path)
            #每隔10000个就存入一个文件
            if current_index % 10000 == 0:
                if writer:
                    writer.close()
                record_filename = "{record_location} - {current_index}.tfrecords".format(
                    record_location = record_location,
                    current_index = current_index
                )
                print(record_filename + "----------------------------")
            current_index += 1
            image_file = tf.read_file(newfilepath)
            try:
                image = tf.image.decode_jpeg(newfilepath)
            except:
                print(image_file)
                continue






# write_data(training_dataset,"train_set.tfrecords")

def to_one_hot(classes, label):
    num_classes = len(classes)
    # print(num_classes)
    # print("label-----------",label)
    label_arr = np.zeros((num_classes))
    # print("label_arr---------",label_arr)
    label_arr[label] += 1.0
    # print("after change label_arr",label_arr)
    return label_arr

#CNN,这个方法就是将tensor张量转化为images转化为int数组和label转化为ont-hot编码
def importimg(imagepath,m,classes):
    #imagepath为读入的图片tfrecord的地址
    #imagepath = "data_train.tfrecords"
    # min_after_dequeue = 15
    # batch_size = 1
    # capacity = min_after_dequeue + 3 * batch_size
    # print(imagepath)
    # print("m------------",m)
    print("开始读入数据----------------------------------")
    filename_queue = tf.train.string_input_producer([imagepath]) #读入流中
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)   #返回文件名和文件
    features = tf.parse_single_example(serialized_example,
                                       features={
                                           'label': tf.FixedLenFeature([], tf.int64),
                                           'img_raw': tf.FixedLenFeature([], tf.string),
                                       })  # 取出包含image和label的feature对象
    image = tf.decode_raw(features['img_raw'], tf.uint8)
    # print("从tfrecord文件中读取数据image", image)
    image = tf.reshape(image, [-1])
    # print("after reshape of image-----------------",image)
    label = tf.cast(features['label'], tf.int32)  # 在流中抛出label张量
    with tf.Session() as sess:  # 开始一个会话
        init_op = tf.global_variables_initializer()
        sess.run(init_op)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        labels = []
        images = []
        for i in range(m):
            print("第", i, "个",imagepath,"数据正在读取中")
            image1, label1 = sess.run([image, label])  # 在会话中取出image和label
            image = tf.cast(image1, tf.float32)
            label_arr = to_one_hot(classes, label1)
            labels.append(label_arr)
            images.append(image1)
            labels_arr = np.array(labels)
            images_arr = np.array(images)
        # print("labels_arr------------",labels_arr)
        # print("images_arr------------",images_arr)
        coord.request_stop()
        coord.join(threads)
    return images_arr, labels_arr

参考资料:http://blog.csdn.net/xierhacker/article/details/72357651


制作自己的数据集可以按照以下步骤进行: 1. 确定数据集的目的和主题:首先,明确你想要构建数据集的目的和主题,例如自然语言处理、计算机视觉等。 2. 收集和筛选数据:根据你的主题,在互联网上搜索相关的数据源,或者创建自己的数据。确保数据的质量和准确性,同时尽量涵盖不同的情况和变化。 3. 数据清洗和预处理:对收集到的数据进行清洗和预处理,以去除无效或冗余的数据,并将数据转换为适合模型训练的格式。这可能包括文本清洗、图像裁剪、标注等操作。 4. 标注和注释数据:根据你的需求,对数据进行标注和注释,以便训练模型能够理解和学习数据的含义。例如,对文本数据可以进行分类、命名实体识别等标注,对图像数据可以进行目标检测、分割等注释。 5. 划分训练集和测试集:将数据集划分为训练集和测试集,用于模型的训练和评估。通常,训练集用于模型的训练,测试集用于评估模型的性能。 6. 数据增强(可选):如果你的数据量有限,可以使用数据增强技术生更多的训练样本。例如,对图像进行旋转、翻转、缩放等操作,对文本进行词语替换、重排等操作。 7. 数据集格式:根据你使用的模型和框架要求,将数据集保存为特定的格式,如CSV、JSON、TFRecord等。 8. 数据集的文档和元数据:为了方便其他人使用你的数据集,你可以提供相关的文档和元数据,包括数据集的描述、格式说明、标注规范等。 9. 数据集的分享与发布:如果你希望与他人共享你的数据集,可以将其上传到数据集共享平台或者在论文、博客等中公开分享。 请注意,在制作自己的数据集时,需要遵守相关的法律法规和道德准则,尊重数据的隐私和版权。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值