Tensorflow---TFRecord文件的创建与读取

一、环境要求

Python==3.7
Tensorflow==2.4.0

二、数据准备

本博客演示的数据集可以通过以下链接进行下载
百度网盘提取码:9q0f

三、TFRecord文件的创建

import os
import cv2
import tensorflow as tf
import tqdm as tqdm

if __name__ == '__main__':
    # 需要处理的数据集根目录
    Dataset_path = 'Split_Dataset'
    # 用于生成需要写入的文件路径
    file_names = [i.split('.')[0] for i in os.listdir(os.path.join(Dataset_path, 'DSM'))]
    dsm_filename = [os.path.join(Dataset_path, 'DSM', i + '.tif') for i in file_names]
    label_filename = [os.path.join(Dataset_path, 'Label', i + '.png') for i in file_names]
    rgb_filename = [os.path.join(Dataset_path, 'RGB', i + '.png') for i in file_names]
    # 开始进行TFRecords文件的构建
    with tf.io.TFRecordWriter('Potsdam.tfrecords') as writer:
        tqdm_file = tqdm.tqdm(iterable=zip(dsm_filename, rgb_filename, label_filename), total=len(dsm_filename))
        for dsm, rgb, label in tqdm_file:
            # 进行高程图的读取
            dsm_data = cv2.imread(dsm, -1)
            dsm_data = dsm_data.ravel()

            # 进行光学图的读取
            rgb_data = open(rgb, 'rb').read()

            # 进行标签的读取
            label_data = open(label, 'rb').read()

            # 建立tf.train.Feature字典
            feature = {
                'dsm': tf.train.Feature(float_list=tf.train.FloatList(value=dsm_data)),
                'rgb': tf.train.Feature(bytes_list=tf.train.BytesList(value=[rgb_data])),
                'label': tf.train.Feature(bytes_list=tf.train.BytesList(value=[label_data]))
            }

            # 通过字典创建Example
            example = tf.train.Example(features=tf.train.Features(feature=feature))
            # 将Example序列化并写入TFRecords文件
            writer.write(example.SerializeToString())
        tqdm_file.close()

四、TFRecord文件的读取

import tensorflow as tf
import matplotlib.pyplot as plt

# 构造Feature结构,告诉解码器每个Feature是什么
feature_description = {
    'dsm': tf.io.FixedLenFeature([512, 512, 1], tf.float32),
    'rgb': tf.io.FixedLenFeature([], tf.string),
    'label': tf.io.FixedLenFeature([], tf.string)
}


# Example的解析函数
def parse_example(example_string):
    feature = tf.io.parse_single_example(serialized=example_string, features=feature_description)

    feature['rgb'] = tf.image.decode_png(feature['rgb'], channels=3)
    feature['rgb'] = tf.image.resize(feature['rgb'], [512, 512])
    feature['rgb'] = tf.cast(feature['rgb'], tf.float32)
    feature['rgb'] = feature['rgb'] / 255

    feature['label'] = tf.image.decode_png(feature['label'], channels=1)
    feature['label'] = tf.image.resize(feature['label'], [512, 512])
    feature['label'] = tf.cast(feature['label'], tf.int64)
    feature['label'] = tf.squeeze(feature['label'])

    return feature['dsm'], feature['rgb'], feature['label']


if __name__ == '__main__':
    dataset = tf.data.TFRecordDataset(r'Potsdam.tfrecords')
    dataset = dataset.map(parse_example)
    dataset = dataset.shuffle(buffer_size=500)

    for dsm, rgb, label in dataset.take(1):
        temp = (dsm - tf.reduce_min(input_tensor=dsm, axis=(0, 1, 2))) / (
                tf.reduce_max(input_tensor=dsm, axis=(0, 1, 2)) - tf.reduce_min(input_tensor=dsm, axis=(0, 1, 2)))

        plt.figure(figsize=(15, 5), dpi=150)
        plt.subplot(1, 3, 1)
        plt.imshow(temp)
        plt.axis('off')

        plt.subplot(1, 3, 2)
        plt.imshow(rgb)
        plt.axis('off')

        plt.subplot(1, 3, 3)
        plt.imshow(label)
        plt.axis('off')

        plt.show()

五、TFRecord文件的读取结果

在这里插入图片描述

六、项目文件的下载

本文的项目源文件点击以下链接进行获取
项目源文件

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

水哥很水

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值