一、环境要求
Python==3.7
TensorFlow==2.4.0
此外,示例代码还依赖 opencv-python(读取高程图)、tqdm(显示进度条)和 matplotlib(可视化读取结果)。
二、数据准备
本博客演示的数据集可以通过以下链接进行下载
百度网盘提取码:9q0f
三、TFRecord文件的创建
import os
import cv2
import tensorflow as tf
import tqdm as tqdm
if __name__ == '__main__':
    # Build a single TFRecord file from the DSM / RGB / Label folders of the dataset.

    # Root directory of the dataset to process.
    Dataset_path = 'Split_Dataset'
    # Sample names are taken from the DSM folder. Only '.tif' entries are used and
    # splitext() removes just the final extension, so names that contain dots stay
    # intact. Sorting makes the record order reproducible, because os.listdir()
    # returns entries in arbitrary order.
    file_names = sorted(
        os.path.splitext(i)[0]
        for i in os.listdir(os.path.join(Dataset_path, 'DSM'))
        if i.endswith('.tif')
    )
    dsm_filename = [os.path.join(Dataset_path, 'DSM', i + '.tif') for i in file_names]
    label_filename = [os.path.join(Dataset_path, 'Label', i + '.png') for i in file_names]
    rgb_filename = [os.path.join(Dataset_path, 'RGB', i + '.png') for i in file_names]

    # Write one Example per sample. Both the writer and the progress bar are context
    # managers, so they are closed even if a sample fails half-way through.
    with tf.io.TFRecordWriter('Potsdam.tfrecords') as writer, \
            tqdm.tqdm(iterable=zip(dsm_filename, rgb_filename, label_filename),
                      total=len(dsm_filename)) as tqdm_file:
        for dsm, rgb, label in tqdm_file:
            # Read the elevation map without any conversion (IMREAD_UNCHANGED == -1).
            dsm_data = cv2.imread(dsm, cv2.IMREAD_UNCHANGED)
            if dsm_data is None:
                # cv2.imread signals a missing or undecodable file by returning None.
                raise OSError(f'Unable to read elevation map: {dsm}')
            # The values are stored flattened; the image shape itself is not written,
            # so whoever reads the file has to supply it.
            dsm_data = dsm_data.ravel()

            # The optical image and the label are stored as their raw file bytes.
            with open(rgb, 'rb') as rgb_file:
                rgb_data = rgb_file.read()
            with open(label, 'rb') as label_file:
                label_data = label_file.read()

            # Map every feature name to its tf.train.Feature.
            feature = {
                'dsm': tf.train.Feature(float_list=tf.train.FloatList(value=dsm_data)),
                'rgb': tf.train.Feature(bytes_list=tf.train.BytesList(value=[rgb_data])),
                'label': tf.train.Feature(bytes_list=tf.train.BytesList(value=[label_data])),
            }
            # Wrap the features in an Example, serialize it and append it to the file.
            example = tf.train.Example(features=tf.train.Features(feature=feature))
            writer.write(example.SerializeToString())
四、TFRecord文件的读取
import tensorflow as tf
import matplotlib.pyplot as plt
# Schema given to the parser: the shape and dtype of every feature in an Example.
# 'dsm' is a fixed 512x512x1 float32 block; 'rgb' and 'label' are scalar byte
# strings (parse_example decodes them as PNG).
feature_description = {
    'dsm': tf.io.FixedLenFeature(shape=[512, 512, 1], dtype=tf.float32),
    'rgb': tf.io.FixedLenFeature(shape=[], dtype=tf.string),
    'label': tf.io.FixedLenFeature(shape=[], dtype=tf.string),
}
def parse_example(example_string):
    """Decode one serialized tf.train.Example into (dsm, rgb, label) tensors.

    Args:
        example_string: scalar string tensor holding a serialized Example that
            follows the layout declared in ``feature_description``.

    Returns:
        A tuple ``(dsm, rgb, label)``:
        dsm is the float32 [512, 512, 1] tensor exactly as stored,
        rgb is a float32 [512, 512, 3] image scaled by 1/255,
        label is an int64 [512, 512] class-index map.
    """
    feature = tf.io.parse_single_example(serialized=example_string, features=feature_description)

    # Optical image: the default (bilinear) resize already returns float32, so
    # only the 1/255 scaling is needed afterwards.
    rgb = tf.image.decode_png(feature['rgb'], channels=3)
    rgb = tf.image.resize(rgb, [512, 512]) / 255

    # Label mask: class ids must never be interpolated, otherwise the borders
    # between classes turn into ids that do not exist. Nearest-neighbour resizing
    # keeps every pixel equal to one of the original ids.
    label = tf.image.decode_png(feature['label'], channels=1)
    label = tf.image.resize(label, [512, 512], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
    label = tf.cast(label, tf.int64)
    # Drop only the channel axis; an axis-less squeeze would remove any size-1 dimension.
    label = tf.squeeze(label, axis=-1)

    return feature['dsm'], rgb, label
if __name__ == '__main__':
    # Read the TFRecord file back and display one random sample:
    # normalised elevation map, optical image and label, side by side.
    dataset = tf.data.TFRecordDataset(r'Potsdam.tfrecords')
    # Shuffle the serialized records before decoding them, so the 500-element
    # shuffle buffer holds raw records instead of decoded image tensors.
    dataset = dataset.shuffle(buffer_size=500)
    dataset = dataset.map(parse_example)

    for dsm, rgb, label in dataset.take(1):
        # Min-max normalise the elevation map for display. divide_no_nan yields 0
        # instead of NaN when the tile is completely flat (max == min).
        dsm_min = tf.reduce_min(input_tensor=dsm)
        dsm_max = tf.reduce_max(input_tensor=dsm)
        temp = tf.math.divide_no_nan(dsm - dsm_min, dsm_max - dsm_min)
        # Remove the channel axis so the map is drawn as a plain 2-D scalar image.
        temp = tf.squeeze(temp, axis=-1)

        plt.figure(figsize=(15, 5), dpi=150)
        for position, image in enumerate((temp, rgb, label), start=1):
            plt.subplot(1, 3, position)
            plt.imshow(image)
            plt.axis('off')
        plt.show()
五、TFRecord文件的读取结果
六、项目文件的下载
点击以下链接即可获取本文的项目源文件:
项目源文件