前言
BigEarthnet数据集就不花笔墨介绍了,查到了这篇文章想必也都有所了解了。该文分享对BigEarthnet-MM数据集划分为训练集,测试集与验证集的官方代码进行的整理与改动,将tensorflow1版本的源代码改动为适配tensorflow2。由于作者在查找BigEarthnet数据集的相关信息时发现都是简单的介绍,而实际数据集在官方示例中还需要进行划分与转化为TFrecord格式进行保存。所以作者也是花了一些时间去官网上学习了一番,踩了一些坑也遇到一些bug,特此分享。
下载数据集
在进行介绍前,先粘贴一下bigearthnet的网址https://bigearth.net/,在官网上可以了解到更多相关信息。在数据划分之前需要提前下载好bigearthnet-MM数据集,官网上对BigEarthnet-MM的介绍就是S1和S2的集合,所以分别下载S1和S2就行了,数据集蛮大的,需要预留出足够的硬盘空间。
官方代码
BigEarthnet-MM数据集解压后需要进行数据集划分成为训练集,测试集与验证集。官方提供的划分为19类的代码链接为https://git.tu-berlin.de/rsim/BigEarthNet-MM_19-classes_models
官方提供的程序需要配置GDAL库或者rasterio库,以及1.15版本的tensorflow,python版本为3.6。作者使用的python版本为3.9.18,tensorflow-gpu版本为2.6.0。
作者在下载GDAL库时遇到还未解决的问题,环境里配置好了GDAL库(3.6.2),但在运行代码时却import不了,降低版本还是不行,目前没找到解决方案,希望有好心人知晓可以在评论区或者私聊分享一下。rasterio库配置好后可以正常使用。
tensorflow_utils.py代码改动
import tensorflow as tf
import numpy as np
import os
import json
# Sentinel-1 SAR band names used to locate the corresponding GeoTIFF files.
band_names_s1 = ['VV', 'VH']
# Sentinel-2 spectral band names used to locate the corresponding GeoTIFF files.
band_names_s2 = [
    'B01', 'B02', 'B03', 'B04', 'B05', 'B06',
    'B07', 'B08', 'B8A', 'B09', 'B11', 'B12',
]
def prep_example(bands, BigEarthNet_19_labels, BigEarthNet_19_labels_multi_hot, patch_name_s1, patch_name_s2):
    """Serialize one BigEarthNet-MM patch into a ``tf.train.Example``.

    Args:
        bands: dict mapping band name -> 2-D array. S2 bands (``band_names_s2``)
            hold integer values, S1 bands (``band_names_s1``) hold floats —
            matching the Int64List/FloatList split below.
        BigEarthNet_19_labels: list of 19-class label strings for this patch.
        BigEarthNet_19_labels_multi_hot: 0/1 multi-hot vector (length 19).
        patch_name_s1: Sentinel-1 patch name.
        patch_name_s2: Sentinel-2 patch name.

    Returns:
        A ``tf.train.Example`` with one feature per band plus the label and
        patch-name features.
    """
    feature = {}
    # Sentinel-2 spectral bands: integer pixel values -> Int64List.
    for band_name in band_names_s2:
        feature[band_name] = tf.train.Feature(
            int64_list=tf.train.Int64List(value=np.ravel(bands[band_name])))
    # Sentinel-1 SAR bands: floating-point values -> FloatList.
    for band_name in band_names_s1:
        feature[band_name] = tf.train.Feature(
            float_list=tf.train.FloatList(value=np.ravel(bands[band_name])))
    feature['BigEarthNet-19_labels'] = tf.train.Feature(
        bytes_list=tf.train.BytesList(
            value=[label.encode('utf-8') for label in BigEarthNet_19_labels]))
    feature['BigEarthNet-19_labels_multi_hot'] = tf.train.Feature(
        int64_list=tf.train.Int64List(value=BigEarthNet_19_labels_multi_hot))
    feature['patch_name_s1'] = tf.train.Feature(
        bytes_list=tf.train.BytesList(value=[patch_name_s1.encode('utf-8')]))
    feature['patch_name_s2'] = tf.train.Feature(
        bytes_list=tf.train.BytesList(value=[patch_name_s2.encode('utf-8')]))
    return tf.train.Example(features=tf.train.Features(feature=feature))
def create_split(root_folder_s1, root_folder_s2, patch_names, TFRecord_writer, label_indices, GDAL_EXISTED, RASTERIO_EXISTED, UPDATE_JSON):
    """Read each patch of one split, convert labels to the 19-class scheme,
    and write serialized examples into ``TFRecord_writer``.

    Args:
        root_folder_s1: root directory containing the Sentinel-1 patch folders.
        root_folder_s2: root directory containing the Sentinel-2 patch folders.
        patch_names: sequence of (S2 name, S1 name) pairs for this split.
        TFRecord_writer: an open TFRecord writer; not closed here.
        label_indices: dict with 'original_labels', 'label_conversion' and
            'BigEarthNet-19_labels' lookup tables.
        GDAL_EXISTED / RASTERIO_EXISTED: which raster library is importable.
        UPDATE_JSON: if True, write the derived 19-class labels back into each
            patch's ``*_labels_metadata.json`` file.

    Raises:
        RuntimeError: if neither GDAL nor rasterio is available.
    """
    label_conversion = label_indices['label_conversion']
    # Invert name->index so indices can be mapped back to label strings.
    BigEarthNet_19_label_idx = {v: k for k, v in label_indices['BigEarthNet-19_labels'].items()}
    if GDAL_EXISTED:
        import gdal
    elif RASTERIO_EXISTED:
        import rasterio

    def _read_band(band_path):
        # Read band 1 of a GeoTIFF with whichever raster library is available.
        if GDAL_EXISTED:
            band_ds = gdal.Open(band_path, gdal.GA_ReadOnly)
            return np.array(band_ds.GetRasterBand(1).ReadAsArray())
        if RASTERIO_EXISTED:
            # 'with' closes the dataset (the original leaked open handles).
            with rasterio.open(band_path) as band_ds:
                return np.array(band_ds.read(1))
        # Original code silently produced an empty bands dict in this case.
        raise RuntimeError('Neither GDAL nor rasterio is available to read ' + band_path)

    # tf.contrib was removed in TF2; Progbar now lives under tf.keras.utils.
    progress_bar = tf.keras.utils.Progbar(target=len(patch_names))
    for patch_idx, patch_name in enumerate(patch_names):
        # Each entry pairs (S2 name, S1 name).
        patch_name_s1, patch_name_s2 = patch_name[1], patch_name[0]
        patch_folder_path_s1 = os.path.join(root_folder_s1, patch_name_s1)
        patch_folder_path_s2 = os.path.join(root_folder_s2, patch_name_s2)

        bands = {}
        for band_name in band_names_s1:
            bands[band_name] = _read_band(os.path.join(
                patch_folder_path_s1, patch_name_s1 + '_' + band_name + '.tif'))
        for band_name in band_names_s2:
            bands[band_name] = _read_band(os.path.join(
                patch_folder_path_s2, patch_name_s2 + '_' + band_name + '.tif'))

        original_labels_multi_hot = np.zeros(
            len(label_indices['original_labels'].keys()), dtype=int)
        BigEarthNet_19_labels_multi_hot = np.zeros(len(label_conversion), dtype=int)
        # Labels metadata sits in the S1 patch folder under the S1 patch name.
        patch_json_path = os.path.join(
            patch_folder_path_s1, patch_name_s1 + '_labels_metadata.json')
        # Text mode ('r'/'w'), not binary: json in Python 3 works with str.
        with open(patch_json_path, 'r') as f:
            patch_json = json.load(f)
        original_labels = patch_json['labels']
        for label in original_labels:
            original_labels_multi_hot[label_indices['original_labels'][label]] = 1
        # Collapse the original multi-hot into the 19-class scheme: a 19-class
        # label is set if any of its constituent original labels is set.
        for i in range(len(label_conversion)):
            BigEarthNet_19_labels_multi_hot[i] = (
                np.sum(original_labels_multi_hot[label_conversion[i]]) > 0
            ).astype(int)
        BigEarthNet_19_labels = [
            BigEarthNet_19_label_idx[i]
            for i in np.where(BigEarthNet_19_labels_multi_hot == 1)[0]
        ]
        if UPDATE_JSON:
            patch_json['BigEarthNet_19_labels'] = BigEarthNet_19_labels
            with open(patch_json_path, 'w') as f:
                json.dump(patch_json, f)

        example = prep_example(
            bands,
            BigEarthNet_19_labels,
            BigEarthNet_19_labels_multi_hot,
            patch_name_s1,
            patch_name_s2
        )
        TFRecord_writer.write(example.SerializeToString())
        # Fix off-by-one: update with the number of finished patches (1-based)
        # so the bar can actually reach its target of len(patch_names).
        progress_bar.update(patch_idx + 1)
def prep_tf_record_files(root_folder_s1, root_folder_s2, out_folder, split_names, patch_names_list, label_indices, GDAL_EXISTED, RASTERIO_EXISTED, UPDATE_JSON):
    """Create one ``<split>.tfrecord`` file per split in ``out_folder`` and
    fill each by delegating to ``create_split``.

    Args:
        root_folder_s1 / root_folder_s2: Sentinel-1 / Sentinel-2 root folders.
        out_folder: destination directory for the TFRecord files.
        split_names: split names, e.g. ['train', 'val', 'test'].
        patch_names_list: one list of patch-name pairs per split, aligned
            with ``split_names``.
        label_indices: label lookup tables forwarded to ``create_split``.
        GDAL_EXISTED / RASTERIO_EXISTED / UPDATE_JSON: forwarded as-is.
    """
    try:
        # tf.python_io.TFRecordWriter was removed in TF2; tf.io replaces it.
        writer_list = [
            tf.io.TFRecordWriter(os.path.join(out_folder, split_name + '.tfrecord'))
            for split_name in split_names
        ]
    except Exception as err:
        # Narrowed from a bare 'except:' so KeyboardInterrupt/SystemExit
        # still propagate; report the actual cause before bailing out.
        print('ERROR: TFRecord writer is not able to write files')
        print(err)
        exit()
    for split_name, patch_names, writer in zip(split_names, patch_names_list, writer_list):
        print('INFO: creating the split of', split_name, 'is started')
        try:
            create_split(
                root_folder_s1,
                root_folder_s2,
                patch_names,
                writer,
                label_indices,
                GDAL_EXISTED,
                RASTERIO_EXISTED,
                UPDATE_JSON
            )
        finally:
            # Close the writer even if the split fails part-way, so the
            # records written so far are flushed to disk.
            writer.close()
修改完毕后cd到项目文件夹,在命令行输入指令
python prep_splits_19_classes.py -r1 root_folder_s1 -r2 root_folder_s2 -o out_folder -n "./BigEarthNet-MM_19-classes_models-master/splits/test.csv" "./BigEarthNet-MM_19-classes_models-master/splits/train.csv" "./BigEarthNet-MM_19-classes_models-master/splits/val.csv" --update_json -l tensorflow
然后正常运行了,out_folder中会建立三个TFrecord文件。但由于数据量太大了,需要跑很长时间,所以笔者在考虑换数据集实验了,耐不住时间了。