tensorflow数据加载之TFRecord

话不多说,干就完了。


TFRecord是什么?

TFRecord是tensorflow中定义的一种数据格式,这种格式的数据在模型训练中方便将数据喂给模型。在初学tensorflow时肯定都实验过tensorflow中自带的mnist手写数字识别的例子,在那个例子中需要下载一个mnist二进制的数据集,那个数据集是经过处理的二进制文件,我们可以直接使用,并不需要关心图像数据的加载、预处理等操作。也就是说在mnist这个例子中,我们直接跳过接触原始数据的机会,仅仅关心网络模型的训练。但是在实际项目中却没有这样现成的数据文件。TFRecord就是这样一种将原始图像数据保存为二进制数据文件的数据格式,只有将原始图像数据转换为二进制的数据(可以简单理解为矩阵)才能被模型使用。

TFRecord的数据存储结构:

TFRecord文件由Example、Features、Feature组成,结构如下:

message Example {
 Features features = 1;
};

message Features{
 map<string,Feature> featrue = 1;
};

message Feature{
    oneof kind{
        BytesList bytes_list = 1;
        FloatList float_list = 2;
        Int64List int64_list = 3;
    }
};

TFRecord使用流程:

原始图像生成TFRecord数据文件:

图像数据格式如下:

生成TFRecordwenj文件:

def save_to_tfrecord():
    def _int64_feature(value):
        return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

    def _bytes_feature(value):
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

    dataset_path = os.path.join(".", "dataset")
    tfrecord_file_name = os.path.join(".", "tfrecords_files", "cat_and_dog_datasets.tfrecord")

    tfrecord_writer = tf.python_io.TFRecordWriter(path=tfrecord_file_name)
    for root, dirs, files in os.walk(dataset_path):
        if len(dirs) == 0:
            cls_name = os.path.split(root)[-1]
            for image_name in files:
                image_path = os.path.join(root, image_name)
                img = cv2.imread(filename=image_path)
                if img.ndim == 3:
                    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                    img = cv2.resize(img, (224, 224))
                    img_pixels = img.shape[0] * img.shape[1]

                    img_raw = img.tostring()
                    print(img_pixels)
                    img_example = tf.train.Example(
                        features=tf.train.Features(
                            feature={
                                "pixels": _int64_feature(img_pixels),
                                "labels": _bytes_feature(cls_name.encode("utf8")),
                                "image_raw": _bytes_feature(img_raw)
                            }
                        )
                    )
                    tfrecord_writer.write(img_example.SerializeToString())
    tfrecord_writer.close()

加载TFRecord数据文件:

def load_from_tfrecord():
    tfrecord_reader = tf.TFRecordReader()
    tfrecord_file_name = ["tfrecords_files/cat_and_dog_datasets.tfrecord"]

    tfrecord_file_queue = tf.train.string_input_producer(
        string_tensor=tfrecord_file_name,
        shuffle=True,
        num_epochs=200
    )

    _, serialized_example = tfrecord_reader.read(queue=tfrecord_file_queue)

    features = tf.parse_single_example(
        serialized=serialized_example,
        features={
            "pixels": tf.FixedLenFeature(shape=[], dtype=tf.int64),
            "labels": tf.FixedLenFeature(shape=[], dtype=tf.string),
            "image_raw": tf.FixedLenFeature(shape=[], dtype=tf.string)
        }
    )

    pixels = tf.cast(features["pixels"], tf.int64)
    labels = tf.cast(features["labels"], tf.string)
    img_raw = tf.decode_raw(bytes=features["image_raw"], out_type=tf.uint8)

    init_op = [tf.local_variables_initializer(), tf.global_variables_initializer()]
    sess.run(init_op)

    # 这两行代码至关重要,表示使用协程启动数据加载线程
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)

    for i in range(4):
        print(i * "#####")
        img_pixels, img_labels, img = sess.run([pixels, labels, img_raw])
        print("img pixels -> %s" % img_pixels)
        print("img labels -> %s" % img_labels)
        print("img matrix -> %s" % img.shape)

        plt.imshow(img.reshape((224, 224, 3)))
        plt.show()

参考:https://blog.csdn.net/zxyhhjs2017/article/details/82556746

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值