天气预报大模型GraphCast使用以及GPU加速环境配置

前言:

这是本人在使用github上下载的GraphCast0.1版本程序时遇到的问题以及解决途径。

参考内容见文章末尾。

使用环境:

WSL-Ubuntu/Cuda 12.2/Cudnn 8.9.4/jax0.4.28/jaxlib-0.4.28+cuda12.cudnn89/python 3.10

值得注意的是,建议GraphCast的目录有以下组成,方便新手上路理解:具体的数据下载步骤我会在下边展示

GraphCast目录下
├── code
│   ├── GraphCast-from-Ground-Zero
│       ├──graphcast
│       ├──tree
│       ├──wrapt
│       ├──graphcast_demo.ipynb
│       ├──README.md
│       ├──setup.py
│       ├──...
├── data
│   ├── dataset(需要输入的数据)
│       ├──dataset-source-era5_date-2022-01-01_res-1.0_levels-13_steps-01.nc
│       ├──dataset-source-era5_date-2022-01-01_res-1.0_levels-13_steps-04.nc
│       ├──dataset-source-era5_date-2022-01-01_res-1.0_levels-13_steps-12.nc
│       ├──...
│   ├── params(预训练模型)
│       ├──params-GraphCast - ERA5 1979-2017 - resolution 0.25 - pressure levels 37 - mesh 2to6 - precipitation input and output.npz
│       ├──params-GraphCast_small - ERA5 1979-2015 - resolution 1.0 - pressure levels 13 - mesh 2to5 - precipitation input and output.npz
│       ├──...
│   ├── stats(规范化数据)
│       ├──stats-mean_by_level.nc
│       ├──...

如我的文件放置:


一、Graphcast程序下载:

        对于普通版本的GraphCast的下载,参见https://github.com/google-deepmind/graphcast即可。这里推荐使用https://github.com/sfsun67/GraphCast-from-Ground-Zero。以下所有官方说明均以google官方为主,这里边建议的原因是因为graphcast_demo_Chinese.ipynb中部分注释为中文(graphcast_demo_Chinese.ipynb同官方graphcast_demo.ipynb),对新手上路有所帮助。

        在code文件夹中运行以下命令:

git clone https://github.com/sfsun67/GraphCast-from-Ground-Zero

二、python环境配置 :

1.python选择3.10版本或者3.11:

        这个要求是在graphcast-0.1/setup.py中提到的:

'''
setup中说明如下,此外还说明了系统应选择Linux
    classifiers=[
        "Development Status :: 3 - Alpha",
        "Intended Audience :: Science/Research",
        "License :: OSI Approved :: Apache Software License",
        "Operating System :: POSIX :: Linux",
        "Programming Language :: Python :: 3.10",
        "Programming Language :: Python :: 3.11",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
        "Topic :: Scientific/Engineering :: Atmospheric Science",
        "Topic :: Scientific/Engineering :: Physics",
'''
conda create -n graphcast python=3.10
conda activate graphcast #进入新下载的环境

 2. graphcast的安装

         在这一步中,仅需要使用python运行setup程序即可,他会给用户下载基础的依赖库,但他其中的jax等内容在下边步骤会更改。

python setup.py install
conda install -y -c conda-forge ipywidgets#如果使用vscode还需要自己安装这个库
#此外按照步骤来进行,最后运行graphcast_demo.ipynb时缺哪个库就直接下载就可以

3.GPU环境:

        首先使用以下命令,查看GPU驱动(由于WSL-Ubuntu中GPU依赖于Windows,所以自带有GPU驱动)

nvidia-smi

        下图中Driver Version代表GPU驱动版本,CUDA Version代表可以接受的最高CUDA版本

(由于我的GPU版本原先为537.58,仅支持CUDA12.2所以本次过程我使用了CUDA12.2,下边我会解释为什么我会更新GPU驱动)

3.1 CUDA下载 :

        在CUDA Toolkit Archive | NVIDIA Developer中选择所需要的版本点击进入,并选择好配置后,下边会显示下载命令以及安装命令,安装时会有卡顿,后边的全部接收就可以,下载以后使用 nvcc --version检查:

wget https://developer.download.nvidia.com/compute/cuda/12.2.0/local_installers/cuda_12.2.0_535.54.03_linux.run
sudo sh cuda_12.2.0_535.54.03_linux.run

#随后使用 nvcc --version查看下载版本是否一致
(graphcast) Somnr1@Somnr1:~/Graphcast/graphcast-0.1$ nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2023 NVIDIA Corporation
Built on Tue_Jun_13_19:16:58_PDT_2023
Cuda compilation tools, release 12.2, V12.2.91
Build cuda_12.2.r12.2/compiler.32965470_0

3.2 Cudnn下载:        

        对应的Cudnn下载需要与CUDA版本对齐,可以在官网https://developer.nvidia.com/rdp/cudnn-archive找到所需版本:

        对下载(.tar.xz文件)文件解压(tar Jxvf ):

tar -Jxvf cudnn-linux-x86_64-8.9.7.29_cuda12-archive.tar.xz

        进入下载后的文件进行以下步骤:

sudo cp include/cudnn*.h /usr/local/cuda/include
sudo cp lib/libcudnn* /usr/local/cuda/lib64
sudo chmod a+r /usr/local/cuda/include/cudnn*.h /usr/local/cuda/lib64/libcudnn*

        检查,返回说明8.9.7版本Cudnn安装完毕:

cat /usr/local/cuda/include/cudnn_version.h | grep CUDNN_MAJOR -A 2
#返回内容:
(graphcast) Somnr1@Somnr1:~/InstallationPackage$ cat /usr/local/cuda/include/cudnn_version.h | grep CUDNN_MAJOR -A 2
#define CUDNN_MAJOR 8
#define CUDNN_MINOR 9
#define CUDNN_PATCHLEVEL 7
--
#define CUDNN_VERSION (CUDNN_MAJOR * 1000 + CUDNN_MINOR * 100 + CUDNN_PATCHLEVEL)

/* cannot use constexpr here since this is a C-only file */

3.3 JAX安装:

       由于我们setup graphcast时下载了jax,所以先卸载:

pip uninstall jax jaxlib

        jax以及jaxlib的下载要严格按照CUDA以及Cudnn的版本下载,否则会出现以下示例等问题:

E1007 16:53:46.694790   32209 cuda_dnn.cc:502] There was an error before creating cudnn handle (500): cudaErrorSymbolNotFound : named symbol not found

        下载流程:先下载jaxlib在下载jax:

        在 https://storage.googleapis.com/jax-releases/jax_cuda_releases.html网站中我们可以找到我们所需要的jaxlib对应CUDA以及Cudnn版本的下载文件:

            将文件下载下来后pip安装即可,对应的jax版本与jaxlib相同即可:

#安装指定版本的 JAXlib
pip install https://storage.googleapis.com/jax-releases/cuda12/jaxlib-0.4.28+cuda12.cudnn89-cp310-cp310-manylinux2014_x86_64.whl

#安装对应版本的 JAX
pip install jax==0.4.28

检查,结果如下代表安装成功:

(graphcast) Somnr1@Somnr1:~/InstallationPackage$ python
Python 3.10.14 (main, May  6 2024, 19:42:50) [GCC 11.2.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax
>>> print(jax.devices())
[cuda(id=0)]
>>> jax.random.PRNGKey(0)
Array([0, 0], dtype=uint32)

3.4下载jax以及jaxlib对应版本时与graphcast部分库冲突解决方法:

我遇到了三个库冲突,对应jax版本下载对应版本即可

pip uninstall flax
pip uninstall dm-haiku
pip uninstall orbax-checkpoint

pip install flax==0.6.10
pip install dm-haiku==0.0.11
pip install orbax-checkpoint==0.6.0

4. GPU驱动 :

我在运行时提示:

>>> jax.random.PRNGKey(0)
2024-10-08 13:44:37.994428: W external/xla/xla/service/gpu/nvptx_compiler.cc:760] The NVIDIA driver's CUDA version is 12.2 which is older than the ptxas CUDA version (12.6.77). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages

这说明我们的GPU版本过低,在WINDOWS中下载 GEFORCE EXPERENCE更新驱动即可(不要下载CUDA向下兼容器,他无法从根源解决问题,只是不会报错了)

但有一个问题:上述提示说我的ptxas CUDA version 为12.6.77,但我使用ptxas --version查看时提示版本为12.2,这个问题并不知道具体原因(不过更新驱动可以解决版本导致的无法并行运算这个问题)

(graphcast) Somnr1@Somnr1:~$ ptxas --version
ptxas: NVIDIA (R) Ptx optimizing assembler
Copyright (c) 2005-2023 NVIDIA Corporation
Built on Tue_Jun_13_19:13:58_PDT_2023
Cuda compilation tools, release 12.2, V12.2.91
Build cuda_12.2.r12.2/compiler.32965470_0


三、Graphcast所需文件下载

训练所需要输入数据、预测模型以及规范化数据可以从Google Cloud Bucket中下载:https://console.cloud.google.com/storage/browser/dm_graphcast 

1.Google Cloud Bucket数据下载

        1. gsutil下载以及安装:

                https://cloud.google.com/storage/docs/gsutil_install?hl=zh-cn#linux

        2. gsutil使用:

#查看文件夹下的内容
gsutil ls gs:///dm_graphcast

#先串行下载
gsutil cp -R gs:///dm_graphcast

2.输入数据下载

        其中输入数据可以使用Google Cloud Bucket做测试,往后可以自由选择想要测试的数据,注意不同数据要和不同模型相对应。

        这里给EC数据下载方式:CDS下载

        1. 账号注册:

                ECMWF | Advancing global NWP through international collaboration

        2. 在home目录下创建.cdsapirc文件内容为:

                注意冒号和内容间不要有空格(官方教程有所以特加注释)

url:https://cds-beta.climate.copernicus.eu/api
key:********-****-****-****-************

3. python环境中安装CDS客户端:

pip install 'cdsapi>=0.7.2'

4. 数据下载:官方例子

  import cdsapi

  client = cdsapi.Client()
  
  dataset = 'reanalysis-era5-pressure-levels'
  request = {
      'product_type': ['reanalysis'],
      'variable': ['geopotential'],
      'year': ['2024'],
      'month': ['03'],
      'day': ['01'],
      'time': ['13:00'],
      'pressure_level': ['1000'],
      'data_format': 'grib',
  }
  target = 'download.grib'
  
  client.retrieve(dataset, request, target)

5.如果出现RuntimeError: 403 Client Error: required licences not accepted

就去官网对应文件下载网页,找到对应条款登录签约即可


四、Graphcast运行 

1. 打开graphcast_demo.ipynb文件

以上环境配置都注释,从加在库开始即可

2. 步骤:加载数据并初始化模型 时模型区别

有两种模型参数配置方法:random或者checkpoint,选择后者则说明直接使用graphcast提供的三种模型

前者就是应用该流程程序,但是Mesh Size、GNN Message Steps、Latent Size、Pressure Levels由自己决定

他们的不同点在于:

1. 网格大小指定了地球的内部图形表示。较小的网格将运行更快,但输出将更差。网格大小不影响模型的参数数量。

2. 分辨率和压力级别的数量必须匹配数据。较低的分辨率和较少的级别会运行得更快。数据分辨率仅影响编码器/解码器。

3. 我们的所有模型都预测降水。然而,ERA5包含降水,而HRES不包含。我们标记为 "ERA5" 的模型将降水作为输入,并期望以ERA5数据作为输入,而标记为 "ERA5-HRES" 的模型不以降水作为输入,并专门训练以HRES-fc0作为输入

3. 步骤:载入示例数据 时文件筛选机制

分辨率要求:model_config.resolution;

压力层数要求:len(task_config.pressure_levels)l;

输入变量和数据来源要求:如果 task_config.input_variables 中包含 "total_precipitation_6hr"l,那么文件的数据来源(source)必须是 "era5" 或 "fake",反之必须是 "hres" 或 "fake"

4. 其他

其他步骤直接运行即可

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值