Sharing my debugging of the official BigEarthNet-MM split code (TensorFlow 2)


Preface

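I won't spend words introducing the BigEarthNet dataset; if you found this article, you probably already know it. This post shares my cleanup of the official code that splits the BigEarthNet-MM dataset into training, test, and validation sets, with the original TensorFlow 1 code ported to TensorFlow 2. When I searched for information about BigEarthNet, I only found brief introductions, yet the official examples require the dataset to be split and converted to TFRecord format first. So I spent some time studying the official site, stepped into a few pitfalls and ran into some bugs along the way, and I am sharing what I learned here.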


Downloading the dataset

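First, here is the BigEarthNet website, where you can find more information: https://bigearth.net/. The BigEarthNet-MM dataset must be downloaded before it can be split. The official site describes BigEarthNet-MM as the combination of the S1 (Sentinel-1) and S2 (Sentinel-2) patches, so you simply download S1 and S2 separately. The dataset is quite large, so reserve enough disk space.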
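Because the conversion later runs for a long time, it can pay to check beforehand that the downloaded folders contain every file create_split() will open. The following is a minimal sketch, assuming the folder layout implied by the paths in tensorflow_utils.py (one folder per patch, files named <patch_name>_<band>.tif, and the label JSON stored in the S1 patch folder); the helper name missing_patch_files is my own and not part of the official code.

```python
import os

# Band names as used in tensorflow_utils.py
BAND_NAMES_S1 = ["VV", "VH"]
BAND_NAMES_S2 = ["B01", "B02", "B03", "B04", "B05", "B06",
                 "B07", "B08", "B8A", "B09", "B11", "B12"]


def missing_patch_files(root_folder_s1, root_folder_s2, patch_name_s1, patch_name_s2):
    """Return the expected files of one S1/S2 patch pair that are not on disk.

    Assumed layout (the one create_split() reads from):
      <root_s1>/<patch_name_s1>/<patch_name_s1>_<band>.tif
      <root_s2>/<patch_name_s2>/<patch_name_s2>_<band>.tif
      <root_s1>/<patch_name_s1>/<patch_name_s1>_labels_metadata.json
    """
    folder_s1 = os.path.join(root_folder_s1, patch_name_s1)
    folder_s2 = os.path.join(root_folder_s2, patch_name_s2)
    expected = [os.path.join(folder_s1, f"{patch_name_s1}_{b}.tif") for b in BAND_NAMES_S1]
    expected += [os.path.join(folder_s2, f"{patch_name_s2}_{b}.tif") for b in BAND_NAMES_S2]
    expected.append(os.path.join(folder_s1, f"{patch_name_s1}_labels_metadata.json"))
    # Keep only the paths that do not exist as regular files
    return [p for p in expected if not os.path.isfile(p)]
```

Running this over a few rows of each split CSV before starting the full job makes a bad extraction path show up in seconds instead of hours into the run.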


Official code

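After extracting BigEarthNet-MM, the dataset has to be split into training, test, and validation sets. The official code for the 19-class split is available at https://git.tu-berlin.de/rsim/BigEarthNet-MM_19-classes_models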
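The split definitions are CSV files in the splits folder of that repository (test.csv, train.csv, val.csv). Below is a sketch of how one CSV could be read into the patch-name pairs that create_split() iterates over. It assumes each row holds the S2 patch name in column 0 and the S1 patch name in column 1, with no header row, which is what patch_name[0] and patch_name[1] in create_split() imply; read_split_csv is a hypothetical helper, not the official parser.

```python
import csv
import os


def read_split_csv(csv_path):
    """Read one split CSV into (split_name, [(patch_name_s2, patch_name_s1), ...]).

    Assumptions: no header row; column 0 = S2 patch name, column 1 = S1 patch name.
    The split name is taken from the file name, e.g. "train.csv" -> "train",
    which matches the <split_name>.tfrecord output names.
    """
    split_name = os.path.splitext(os.path.basename(csv_path))[0]
    pairs = []
    with open(csv_path, newline="") as f:
        for row in csv.reader(f):
            if len(row) >= 2:  # skip blank or malformed lines
                pairs.append((row[0].strip(), row[1].strip()))
    return split_name, pairs
```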

The official scripts require either the GDAL or the rasterio library, TensorFlow 1.15, and Python 3.6. I used Python 3.9.18 and tensorflow-gpu 2.6.0.
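Since the port mainly swaps removed TensorFlow 1 APIs (tf.contrib, tf.python_io) for their TensorFlow 2 counterparts, a quick check that the installed TensorFlow provides those replacements can save a failed run. This is a minimal sketch; the attribute list is my own selection of the TensorFlow calls used in the modified tensorflow_utils.py, and the function takes the module as an argument so it can be tried without TensorFlow.

```python
# TF2 APIs the modified tensorflow_utils.py relies on
# (tf.contrib.keras.utils.Progbar -> tf.keras.utils.Progbar,
#  tf.python_io.TFRecordWriter   -> tf.io.TFRecordWriter)
REQUIRED_TF2_APIS = [
    "io.TFRecordWriter",
    "keras.utils.Progbar",
    "train.Example",
    "train.Features",
    "train.Feature",
]


def _has_attr_path(module, dotted):
    """True if module.<a>.<b>... exists for the dotted path "a.b..."."""
    obj = module
    for part in dotted.split("."):
        if not hasattr(obj, part):
            return False
        obj = getattr(obj, part)
    return True


def missing_tf2_apis(tf):
    """List the entries of REQUIRED_TF2_APIS that the given tf module lacks."""
    return ["tf." + path for path in REQUIRED_TF2_APIS if not _has_attr_path(tf, path)]
```

With tensorflow-gpu 2.6.0, import tensorflow as tf followed by print(missing_tf2_apis(tf)) should print an empty list.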

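When setting up GDAL I hit a problem I could not solve: GDAL (3.6.2) was installed in my environment, but the code could not import it, and downgrading did not help. At the time of writing I had not found a solution (see the 2024/2/26 update at the end of this post); if anyone knows the cause, please share it in the comments or by private message. rasterio, once installed, worked without problems.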


Changes to tensorflow_utils.py

import tensorflow as tf
import numpy as np
import os
import json

# SAR band names to read related GeoTIFF files
band_names_s1 = ["VV", "VH"]

# Spectral band names to read related GeoTIFF files
band_names_s2 = ['B01', 'B02', 'B03', 'B04', 'B05',
                 'B06', 'B07', 'B08', 'B8A', 'B09', 'B11', 'B12']

def prep_example(bands, BigEarthNet_19_labels, BigEarthNet_19_labels_multi_hot, patch_name_s1, patch_name_s2):
    return tf.train.Example(
            features=tf.train.Features(
                feature={
                    'B01': tf.train.Feature(
                        int64_list=tf.train.Int64List(value=np.ravel(bands['B01']))),
                    'B02': tf.train.Feature(
                        int64_list=tf.train.Int64List(value=np.ravel(bands['B02']))),
                    'B03': tf.train.Feature(
                        int64_list=tf.train.Int64List(value=np.ravel(bands['B03']))),
                    'B04': tf.train.Feature(
                        int64_list=tf.train.Int64List(value=np.ravel(bands['B04']))),
                    'B05': tf.train.Feature(
                        int64_list=tf.train.Int64List(value=np.ravel(bands['B05']))),
                    'B06': tf.train.Feature(
                        int64_list=tf.train.Int64List(value=np.ravel(bands['B06']))),
                    'B07': tf.train.Feature(
                        int64_list=tf.train.Int64List(value=np.ravel(bands['B07']))),
                    'B08': tf.train.Feature(
                        int64_list=tf.train.Int64List(value=np.ravel(bands['B08']))),
                    'B8A': tf.train.Feature(
                        int64_list=tf.train.Int64List(value=np.ravel(bands['B8A']))),
                    'B09': tf.train.Feature(
                        int64_list=tf.train.Int64List(value=np.ravel(bands['B09']))),
                    'B11': tf.train.Feature(
                        int64_list=tf.train.Int64List(value=np.ravel(bands['B11']))),
                    'B12': tf.train.Feature(
                        int64_list=tf.train.Int64List(value=np.ravel(bands['B12']))),
                    "VV":  tf.train.Feature(
                        float_list=tf.train.FloatList(value=np.ravel(bands['VV']))),
                    "VH":  tf.train.Feature(
                        float_list=tf.train.FloatList(value=np.ravel(bands['VH']))),
                    'BigEarthNet-19_labels': tf.train.Feature(
                        bytes_list=tf.train.BytesList(
                            value=[i.encode('utf-8') for i in BigEarthNet_19_labels])),
                    'BigEarthNet-19_labels_multi_hot': tf.train.Feature(
                        int64_list=tf.train.Int64List(value=BigEarthNet_19_labels_multi_hot)),
                    'patch_name_s1': tf.train.Feature(
                        bytes_list=tf.train.BytesList(value=[patch_name_s1.encode('utf-8')])),
                    'patch_name_s2': tf.train.Feature(
                        bytes_list=tf.train.BytesList(value=[patch_name_s2.encode('utf-8')]))
                }))
    
def create_split(root_folder_s1, root_folder_s2, patch_names, TFRecord_writer, label_indices, GDAL_EXISTED, RASTERIO_EXISTED, UPDATE_JSON):
    label_conversion = label_indices['label_conversion']
    BigEarthNet_19_label_idx = {v: k for k, v in label_indices['BigEarthNet-19_labels'].items()}
    if GDAL_EXISTED:
        from osgeo import gdal  # original: import gdal; GDAL 3.x exposes its Python bindings through osgeo
    elif RASTERIO_EXISTED:
        import rasterio
    progress_bar = tf.keras.utils.Progbar(target=len(patch_names))  # original: tf.contrib.keras.utils.Progbar(target=len(patch_names)); tf.contrib was removed in TensorFlow 2
    for patch_idx, patch_name in enumerate(patch_names):
        patch_name_s1, patch_name_s2 = patch_name[1], patch_name[0]
        patch_folder_path_s1 = os.path.join(root_folder_s1, patch_name_s1)
        patch_folder_path_s2 = os.path.join(root_folder_s2, patch_name_s2)

        bands = {}
        for band_name in band_names_s1:
            band_path = os.path.join(
                patch_folder_path_s1, patch_name_s1 + '_' + band_name + '.tif')
            if GDAL_EXISTED:
                band_ds = gdal.Open(band_path,  gdal.GA_ReadOnly)
                raster_band = band_ds.GetRasterBand(1)
                band_data = raster_band.ReadAsArray()
                bands[band_name] = np.array(band_data)
            elif RASTERIO_EXISTED:
                with rasterio.open(band_path) as band_ds:  # close each GeoTIFF right after reading it
                    bands[band_name] = band_ds.read(1)

        for band_name in band_names_s2:
            # First finds related GeoTIFF path and reads values as an array
            band_path = os.path.join(
                patch_folder_path_s2, patch_name_s2 + '_' + band_name + '.tif')
            if GDAL_EXISTED:
                band_ds = gdal.Open(band_path,  gdal.GA_ReadOnly)
                raster_band = band_ds.GetRasterBand(1)
                band_data = raster_band.ReadAsArray()
                bands[band_name] = np.array(band_data)
            elif RASTERIO_EXISTED:
                with rasterio.open(band_path) as band_ds:  # close each GeoTIFF right after reading it
                    bands[band_name] = band_ds.read(1)
        
        original_labels_multi_hot = np.zeros(
            len(label_indices['original_labels'].keys()), dtype=int)
        BigEarthNet_19_labels_multi_hot = np.zeros(len(label_conversion),dtype=int)
        patch_json_path = os.path.join(
            patch_folder_path_s1, patch_name_s1 + '_labels_metadata.json')  # original: patch_name + '_labels_metadata.json'; patch_name holds both the S2 and S1 names, not a single string

        with open(patch_json_path, 'r') as f:  # original: open(patch_json_path, 'rb'); 'b' removed
            patch_json = json.load(f)

        original_labels = patch_json['labels']
        for label in original_labels:
            original_labels_multi_hot[label_indices['original_labels'][label]] = 1

        for i in range(len(label_conversion)):
            BigEarthNet_19_labels_multi_hot[i] = (
                    np.sum(original_labels_multi_hot[label_conversion[i]]) > 0
                ).astype(int)

        BigEarthNet_19_labels = []
        for i in np.where(BigEarthNet_19_labels_multi_hot == 1)[0]:
            BigEarthNet_19_labels.append(BigEarthNet_19_label_idx[i])

        if UPDATE_JSON:
            patch_json['BigEarthNet_19_labels'] = BigEarthNet_19_labels
            with open(patch_json_path, 'w') as f:  # original: open(patch_json_path, 'wb'); in Python 3 json.dump writes str, so text mode is required
                json.dump(patch_json, f)

        example = prep_example(
            bands, 
            BigEarthNet_19_labels,
            BigEarthNet_19_labels_multi_hot,
            patch_name_s1, 
            patch_name_s2
        )
        TFRecord_writer.write(example.SerializeToString())
        progress_bar.update(patch_idx + 1)  # count the patch just written so the bar reaches 100%

def prep_tf_record_files(root_folder_s1, root_folder_s2, out_folder, split_names, patch_names_list, label_indices, GDAL_EXISTED, RASTERIO_EXISTED, UPDATE_JSON):
    try:
        writer_list = []
        for split_name in split_names:
            writer = tf.io.TFRecordWriter(os.path.join(out_folder, split_name + '.tfrecord'))  # tf.python_io does not exist in TensorFlow 2
            # writer = tf.python_io.TFRecordWriter(os.path.join(out_folder, split_name + '.tfrecord'))
            writer_list.append(writer)
            # writer_list.append(
            #        tf.python_io.TFRecordWriter(os.path.join(
            #            out_folder, split_name + '.tfrecord'))
            #    )
    except:
        print('ERROR: TFRecord writer is not able to write files')
        exit()

    for split_idx in range(len(patch_names_list)):
        print('INFO: creating the split of', split_names[split_idx], 'is started')
        create_split(
            root_folder_s1, 
            root_folder_s2,
            patch_names_list[split_idx], 
            writer_list[split_idx],
            label_indices,
            GDAL_EXISTED, 
            RASTERIO_EXISTED, 
            UPDATE_JSON
            )
        writer_list[split_idx].close()
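The label handling in create_split() is the least obvious part: each patch's JSON stores the original (43-class CORINE Land Cover) label names, and label_conversion lists, for each of the 19 classes, the original label indices it merges; a 19-class entry is 1 if any of its original labels is present. Below is a pure-Python sketch of the same mapping, assuming label_indices has the keys used in create_split(); the function name and the toy dictionary in the test are invented for illustration, while the real dictionary is the one the official script loads.

```python
def to_19_class_multi_hot(original_labels, label_indices):
    """Map a patch's original label names to (19-class names, 19-class multi-hot).

    Mirrors create_split(): class i is set to 1 if any original label index
    listed in label_conversion[i] occurs among the patch's labels.
    """
    original_idx = label_indices["original_labels"]       # original name -> index
    conversion = label_indices["label_conversion"]         # one list of original indices per class
    idx_to_19_name = {v: k for k, v in label_indices["BigEarthNet-19_labels"].items()}

    present = {original_idx[name] for name in original_labels}
    multi_hot = [int(any(i in present for i in group)) for group in conversion]
    names = [idx_to_19_name[i] for i, bit in enumerate(multi_hot) if bit == 1]
    return names, multi_hot
```

Original labels that appear in no label_conversion group simply do not contribute, so a patch can end up with an all-zero 19-class vector.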
        

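After making these changes, cd into the project folder and run the following command (replace root_folder_s1, root_folder_s2 and out_folder with your own paths):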

python prep_splits_19_classes.py -r1 root_folder_s1 -r2 root_folder_s2 -o out_folder -n "./BigEarthNet-MM_19-classes_models-master/splits/test.csv" "./BigEarthNet-MM_19-classes_models-master/splits/train.csv" "./BigEarthNet-MM_19-classes_models-master/splits/val.csv" --update_json -l tensorflow

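It then ran normally, and three TFRecord files were created in out_folder. However, the dataset is so large that the conversion takes a very long time, so I am considering switching to another dataset for my experiments; I simply can't afford the time.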
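To confirm that the output files are usable, one can read them back with a parsing spec that mirrors prep_example(). This is a sketch under stated assumptions: the per-band pixel grids below (10 m bands and S1 VV/VH 120x120, 20 m bands 60x60, 60 m bands 20x20) are my understanding of the BigEarthNet patch format and should be checked against one real patch; the helper names are hypothetical. TensorFlow is imported inside the functions so the shape table can be used on its own.

```python
# Assumed pixel grid of each band in one BigEarthNet patch
BAND_SHAPES = {
    "B01": (20, 20), "B09": (20, 20),
    "B05": (60, 60), "B06": (60, 60), "B07": (60, 60),
    "B8A": (60, 60), "B11": (60, 60), "B12": (60, 60),
    "B02": (120, 120), "B03": (120, 120), "B04": (120, 120), "B08": (120, 120),
    "VV": (120, 120), "VH": (120, 120),
}
S1_BANDS = {"VV", "VH"}  # written as FloatList by prep_example(); S2 bands as Int64List


def feature_description(num_classes=19):
    """Parsing spec matching the features written by prep_example()."""
    import tensorflow as tf  # imported here so BAND_SHAPES works without TensorFlow

    desc = {}
    for band, (h, w) in BAND_SHAPES.items():
        dtype = tf.float32 if band in S1_BANDS else tf.int64
        desc[band] = tf.io.FixedLenFeature([h * w], dtype)  # bands were stored flattened
    desc["BigEarthNet-19_labels"] = tf.io.VarLenFeature(tf.string)
    desc["BigEarthNet-19_labels_multi_hot"] = tf.io.FixedLenFeature([num_classes], tf.int64)
    desc["patch_name_s1"] = tf.io.FixedLenFeature([], tf.string)
    desc["patch_name_s2"] = tf.io.FixedLenFeature([], tf.string)
    return desc


def count_and_peek(tfrecord_path):
    """Count the records in one output file and return the first parsed record."""
    import tensorflow as tf

    desc = feature_description()
    first, count = None, 0
    for raw in tf.data.TFRecordDataset(tfrecord_path):
        if first is None:
            first = tf.io.parse_single_example(raw, desc)
        count += 1
    return count, first
```

For example, count, first = count_and_peek("out_folder/val.tfrecord") followed by tf.reshape(first["B02"], BAND_SHAPES["B02"]) restores one band to its 2-D grid. If parsing fails with a size mismatch, the assumed shapes above are the first thing to check.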


Update 2024/2/26: I later found that GDAL has to be imported with from osgeo import gdal rather than import gdal.
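Following from this, the GDAL_EXISTED / RASTERIO_EXISTED flags passed to create_split() can be derived by trying the osgeo import first. The sketch below is my own helper, not the official detection code; if your copy of prep_splits_19_classes.py detects GDAL with a bare import gdal, that line needs the same change, otherwise a working GDAL 3.x installation goes unnoticed and rasterio (or nothing) is used instead.

```python
def detect_raster_library():
    """Return "gdal", "rasterio" or None, preferring GDAL as create_split() does."""
    try:
        from osgeo import gdal  # noqa: F401  GDAL 3.x Python bindings live in osgeo
        return "gdal"
    except ImportError:
        pass
    try:
        import gdal  # noqa: F401  legacy top-level module shipped by older GDAL releases
        return "gdal"
    except ImportError:
        pass
    try:
        import rasterio  # noqa: F401
        return "rasterio"
    except ImportError:
        return None


library = detect_raster_library()
GDAL_EXISTED = library == "gdal"
RASTERIO_EXISTED = library == "rasterio"
```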

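If anything here is wrong, please point it out! If you have questions, leave a comment; whether or not I can solve the problem, I will reply when I see it.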
