海洋预报大模型“羲和”

目录

模型特点与优势

文件下载

运行试验

数据预处理

模型调试与使用

模型结果可视化


模型特点与优势

  1. 高分辨率:“羲和”是全球1/12°高分辨率海洋预报大模型,能够提供更精细的海洋环境预报信息。
  2. 高性能:通过与法国PSY4、加拿大GIOPS等国际主流业务预报系统的对比评测,“羲和”在海水温度剖面、盐度剖面、海表流场、海平面高度等评测要素上表现出更好的性能。
  3. 推理速度快:“羲和”在单块GPU卡上平均仅需3.6秒即可完成1到10天的全球海洋环境逐日预报,比数值预报业务系统快1000倍以上。
  4. 预报时效长:“羲和”可生成长达60天的海流预测结果,且准确率优于世界先进的海洋环境业务预报系统PSY4的10天预报结果。
  5. 架构新颖:“羲和”在Swin-Transformer架构基础上,引入了Ocean-Land Mask机制和组传播机制,使模型能够专注于学习海洋环境的内在机理,并降低计算复杂度。

关于“羲和”模型的详细介绍和研究成果于2024年2月5日在Arxiv上在线公布,论文标题为《XiHe: A Data-Driven Model for Global Ocean Eddy-Resolving Forecasting》

本文是基于GitHub上的代码实现:Ocean-Intelligent-Forecasting / XiHe-GlobalOceanForecasting

(此模型仅给出预训练模型,未给出训练代码,可直接使用)

文件下载

不要在GitHub网页中code下载,文件夹为空,应从readme中链接下载。

读者可以直接从此下载  -> 百度网盘下载

按GitHub文章中要求,文件目录结构如下

├── root
│   ├── input_data
│   │   ├── input_surface_data
│   │   │    ├── input_surface_20190101.npy
│   │   ├── input_deep_data
│   │   │    ├── input_deep_20190101.npy
│   ├── output_data
│   ├── models
│   |   ├── xihe_1to22_1day.onnx
│   |   ├── ...
│   |   ├── xihe_1to22_10day.onnx
│   |   ├── xihe_23to33_1day.onnx
│   |   ├── ...
│   |   ├── xihe_23to33_10day.onnx
│   ├── src
│   |   ├── data.yaml
│   |   ├── normalize_mean_50.npz
│   |   ├── normalize_std_50.npz
│   |   ├── mask_surface.npy
│   |   ├── mask_deep.npy
│   |   ├── mercator_lat.npy
│   |   ├── mercator_lon.npy
│   |   ├── data_process.py
│   |   ├── inference.py 

环境配置可以参考:Python虚拟环境

将下载的环境配置文件pycdo移动到anaconda对应的文件位置,如下:

在pycharm中配置环境

我用的社区版,专业版可能略有不同

(注:pycdo文件中没有python.exe文件)

运行试验

input_data含有预处理输入数据的示例,对应于 2019/01/01 的每日GLORYS12再分析数据的均值

可直接运行inference.py文件,

也可以在终端输入以下命令:

python src/inference.py --lead_day 7 --save_path output_data

 运行完成后,output_data内生成6个 nc 文件,如下:

(关于如何将这六个文件内容可视化,我在下面模型结果可视化会提到。)

数据预处理

模型中的输入数据来源包括GHRSSTERA5GLORYS12 Reanalysis,按文章意思,只使用GHRSSTGLORYS12 Reanalysis同样可行,以下只讨论二者结合的数据预处理。读者可自行研究三项数据预处理。

为方便数据预处理,我新建了一个dispose文件夹,专门存放相关内容,结构如下:

以2019.01.01为例,从链接中下载相关数据,为 nc 文件。

示例中两种情况形状分别为 (1,52,2041,4320)和(1,48,2041,4320)  的 NumPy 数组,2041 和 4320 的尺寸都表示沿纬度和经度的大小,其中数值范围分别为 [-80,90] 度和 [-180,180] 度,间距为 1/12 度。对于每个 2041x4320 的切片,数据格式与从 GLORYS12 再分析官网下载的文件完全相同。

原始GHRSST数据位于 1/20° 网格 (3600 × 7200)上,为达到与示例相同格式,因此需要从 GHRSST网格到 GLORYS 网格进行插值。

在Github问题讨论,有插值的相关论述,以供参考。

这里使用了线性插值,但运行时间仍然过长,以下代码仅供参考,并不确保正确有效。

(作者尝试过并行处理,但由于熟练度低,效果并不理想,这里就不展示相关代码了) 


#此为interpolation.py文件

import xarray as xr
import numpy as np
from scipy.interpolate import griddata
import time

# 获取开始时间
start = time.perf_counter()

# 加载OSTIA和GLORYS数据集
ostia_ds = xr.open_dataset('20190101120000-UKMO-L4_GHRSST-SSTfnd-OSTIA-GLOB-v02.0-fv02.0.nc')
glorys_ds = xr.open_dataset('mercatorglorys12v1_gl12_mean_20190101_R20190102.nc')

# 提取OSTIA的SST数据和坐标
ostia_sst = ostia_ds['analysed_sst'].squeeze()  # 假设只有一个时间步
ostia_lons = ostia_ds['lon'].values
ostia_lats = ostia_ds['lat'].values

# 提取GLORYS的坐标
glorys_lons = glorys_ds['longitude'].values
glorys_lats = glorys_ds['latitude'].values

# 创建与GLORYS网格相匹配的经纬度网格
glorys_lon2d, glorys_lat2d = np.meshgrid(glorys_lons, glorys_lats, indexing='ij')

# 将OSTIA的经纬度展平以便插值
ostia_lon2d, ostia_lat2d = np.meshgrid(ostia_lons, ostia_lats, indexing='ij')
ostia_lon2d_flat = ostia_lon2d.flatten()
ostia_lat2d_flat = ostia_lat2d.flatten()
ostia_sst_flat = ostia_sst.values.flatten()

# 为了避免内存问题,我们可以考虑分块插值
# 假设我们将数据分成多个小块进行处理
block_size = 1000  # 可以根据实际情况调整这个大小
num_blocks = (glorys_lon2d.size // block_size) + 1

# 计算运行时间
end = time.perf_counter()
runTime = end - start
print('开始迭代运行插值\n总迭代次数为:', block_size,'\n当前运行时间:', runTime, "秒","\n\n")

glorys_sst_interp_blocks = []

for i in range(num_blocks):
    block_start = i * block_size
    block_end = min((i + 1) * block_size, glorys_lon2d.size)

    # 使用griddata进行插值(这里使用linear方法以提高速度)
    block_interp = griddata(
        (ostia_lon2d_flat, ostia_lat2d_flat),
        ostia_sst_flat,
        (glorys_lon2d.flatten()[block_start:block_end], glorys_lat2d.flatten()[block_start:block_end]),
        method='linear'                #linear (线性插值)(精度较低)   cubic(立方插值)(精度较高)
    )
    glorys_sst_interp_blocks.append(block_interp)
    print('迭代次数: ', i)
    end = time.perf_counter()
    runTime = end - start
    print("\t运行时间: ", runTime/60, "分钟")

# 将所有插值块合并成一个完整的数组
glorys_sst_interp = np.concatenate(glorys_sst_interp_blocks)
glorys_sst_interp = glorys_sst_interp.reshape(glorys_lat2d.shape)

# 将插值后的SST数据添加到GLORYS数据集中
glorys_ds['sst'] = (('latitude', 'longitude'), glorys_sst_interp)

# 保存插值后的数据集
glorys_ds.to_netcdf('path_to_interpolated_dataset.nc')  #建议替换为辨识度更高的的NetCDF文件名

插值完成后是一个 nc 文件,按示例格式,需要数据重构,生成形状分别为 (1,52,2041,4320)和(1,48,2041,4320)  的 NumPy 数组的2个 npy 文件。


#此为revise.py文件


import netCDF4 as nc
import numpy as np

date = input('请输入日期:\n')
# 打开原始的NetCDF文件
nc_file = 'path_to_interpolated_dataset.nc'  # 替换为你的NetCDF文件名
dataset = nc.Dataset(nc_file, 'r')
print(dataset)
# 提取所有需要的变量
zos = dataset.variables['zos'][:]
uo = dataset.variables['uo'][:]
vo = dataset.variables['vo'][:]
sst = dataset.variables['sst'][:]

# 初始化surface和deep数据的数组
input_surface_data = np.zeros((1, 52, 2041, 4320))
input_deep_data = np.zeros((1, 48, 2041, 4320))

# 填充surface数据
input_surface_data[0, 0, :, :] = zos[0, 0, :, :]
input_surface_data[0, 1, :, :] = uo[0, 0, :, :]
input_surface_data[0, 2, :, :] = vo[0, 0, :, :]
input_surface_data[0, 3, :, :] = sst[0, 0, :, :]

for i in range(22):
    thetao_key = f'thetao_{i}'
    so_key = f'so_{i}'
    uo_key = f'uo_{i}'
    vo_key = f'vo_{i}'

    thetao = dataset.variables[thetao_key][:]
    so = dataset.variables[so_key][:]
    uo = dataset.variables[uo_key][:]
    vo = dataset.variables[vo_key][:]

    input_surface_data[0, 4 + 4 * i, :, :] = thetao[0, 0, :, :]
    input_surface_data[0, 5 + 4 * i, :, :] = so[0, 0, :, :]
    input_surface_data[0, 6 + 4 * i, :, :] = uo[0, 0, :, :]
    input_surface_data[0, 7 + 4 * i, :, :] = vo[0, 0, :, :]

# 填充deep数据
input_deep_data[0, 0, :, :] = zos[0, 0, :, :]
input_deep_data[0, 1, :, :] = uo[0, 0, :, :]
input_deep_data[0, 2, :, :] = vo[0, 0, :, :]
input_deep_data[0, 3, :, :] = sst[0, 0, :, :]

for i in range(22, 33):
    thetao_key = f'thetao_{i}'
    so_key = f'so_{i}'
    uo_key = f'uo_{i}'
    vo_key = f'vo_{i}'

    thetao = dataset.variables[thetao_key][:]
    so = dataset.variables[so_key][:]
    uo = dataset.variables[uo_key][:]
    vo = dataset.variables[vo_key][:]

    input_deep_data[0, 4 + 4 * (i - 22), :, :] = thetao[0, 0, :, :]
    input_deep_data[0, 5 + 4 * (i - 22), :, :] = so[0, 0, :, :]
    input_deep_data[0, 6 + 4 * (i - 22), :, :] = uo[0, 0, :, :]
    input_deep_data[0, 7 + 4 * (i - 22), :, :] = vo[0, 0, :, :]

# 保存为NumPy数组文件

# 指定保存地址和文件名(包含变量名称)
save_path_1 = '../input_data/input_surface_data/input_surface_'
save_path_2 = '../input_data/input_deep_data/input_deep_'
filename_1 = f"{save_path_1}{date}.npy"
filename_2 = f"{save_path_1}{date}.npy"
# 保存数组
np.save(filename_1, input_surface_data)
np.save(filename_2, input_deep_data)

# 关闭NetCDF文件
dataset.close()

模型调试与使用

1.  直接运行inference.py 文件 

在不考虑修改文件名的情况下,inference.py 文件可调试的内容仅为以下两点,“预报天数” 和 “数据日期” 。(此为浅层次,更深层次作者也不懂)

2. 输入终端命令运行程序

仍需按照 “直接运行inference.py 文件” 的情况进行调试,唯一不同的是在代码中已有默认值时,以终端命令为执行标准。 

python src/inference.py --lead_day 7 --save_path output_data

模型结果可视化

inference.py 程序运行完成后,会生成以下六个 nc 文件,为了可视化结果,我额外新建了一个show.py 文件 ,结构如下:

show.py代码如下:

import netCDF4 as nc
import numpy as np
import matplotlib.pyplot as plt
import xarray as xr

# 打开nc文件
nc_file = '20190101_so.nc'  # 将nc_file替换为你的nc文件路径
print(xr.open_dataset(nc_file))
nc_data = nc.Dataset(nc_file, 'r')

# 查看nc文件中包含的变量
print("Variables in the nc file:", nc_data.variables.keys())

# 选择要读取的变量
variable_name = 'so'

longitude = nc_data.variables['longitude'][:]
latitude = nc_data.variables['latitude'][:]
depth = nc_data.variables['depth'][:]  #在 zos 和 sst 的nc文件中,不包含depth变量,使用者可以考虑将其隐藏
time = nc_data.variables['time'][:]
variable = nc_data.variables[variable_name][:]

plt.figure(figsize=(10, 5))

plt.imshow(variable[0, 0, :, :], cmap='jet', origin='lower',
           extent=(longitude.min(), longitude.max(), latitude.min(), latitude.max()))
plt.colorbar(label='Index')
plt.xlabel('Longitude')
plt.ylabel('Latitude')
plt.title('Ocean')
plt.show()


# 以下提供两种额外可视化效果,可供参考
# 计算最大强度投影
# mip = np.max(variable, axis=1)
# plt.imshow(mip[0, :, :],cmap='jet',origin='lower')
# plt.colorbar(label='Index')
# plt.xlabel('Longitude')
# plt.ylabel('Latitude')
# plt.title('Ocean')
# plt.show()

# 计算平均强度投影
# aip = np.mean(variable, axis=1)
# plt.imshow(aip[0, :, :],cmap='jet',origin='lower')
# plt.colorbar(label='Index')
# plt.xlabel('Longitude')
# plt.ylabel('Latitude')
# plt.title('Ocean')
# plt.show()

现在讲解一下上述代码的部分含义以及使用方法。

1.so,thetao,vo,uo的 nc 文件中包含time,depth,latitude,longitude,一共四个变量,而sst,zos的 nc 文件中只包含time,latitude,longitude三个变量,因此在读取数据时,这行代码会读取错误,建议将其注释掉,或者读者可自行修改代码。

2.variable[0, 0, :, :]属于可调试部分,示例代码中使用的是so.nc文件,因此会有4个变量,第一个0代表time,有且仅有 “ 0 ” 一个数据,第二个0代表depth,可替换从0到22中任意数字,分别代表不同深度。两个“ :” 表示的是经度和纬度。

(如果对图像的颜色映射不满意,可自行修改cmap参数)

结果展示

  • 19
    点赞
  • 28
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值