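After installing JAX, only the CPU device is visible
- The main cause is a version mismatch between CUDA and cuDNN.
- GitHub has plenty of similar issues, for example https://github.com/google/jax/issues/971. I tried installing wheel files downloaded directly from https://storage.googleapis.com/jax-releases, and I tried `pip install --upgrade -f https://xxxxxxxx`; both failed.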
Problem description:
After installing `jax` and `jaxlib`, running
from jax.lib import xla_bridge
print(xla_bridge.get_backend().platform)
prints only the cpu device, even though the installed torch and tensorflow can both see the GPU.
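For a slightly more detailed look at what JAX picked up, here is a minimal sketch that uses the public `jax.default_backend()` and `jax.devices()` calls instead of `xla_bridge`. The helper name `jax_backend_report` is my own, and the `try/except` is only there so the snippet also runs on a machine where jax is missing.

```python
# Report which backend JAX selected and which devices it can see.
try:
    import jax
except ImportError:
    jax = None  # jax is not installed in this environment


def jax_backend_report():
    """Return (backend_name, device_strings), or None if jax is unavailable."""
    if jax is None:
        return None
    return jax.default_backend(), [str(d) for d in jax.devices()]


report = jax_backend_report()
if report is None:
    print("jax is not installed in this environment")
else:
    backend, devices = report
    print("backend:", backend)   # "cpu" here means the GPU build is not in use
    print("devices:", devices)
```

If the backend is still `cpu` on a machine with a working GPU, the installed `jaxlib` is most likely a CPU-only build or one built for a different CUDA/cuDNN combination.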
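I went through a long loop of repeated steps: uninstall and switch versions, check CUDA again, install and uninstall, switch versions, install, install via pip, download a wheel and install it... and in the end it finally worked.
Everything above is preamble; the real content starts below.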
- Check the GPU information and confirm that the highest supported CUDA version is 11.3 (`nvcc --version` reported 10.3 for me):
nvidia-smi
Tue Jul 12 22:26:26 2022
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 465.19.01    Driver Version: 465.19.01    CUDA Version: 11.3     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA GeForce ...  On   | 00000000:84:00.0 Off |                  N/A |
| 41%   34C    P8    21W / 260W |     95MiB / 11016MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    0   N/A  N/A      1729      G   /usr/lib/xorg/Xorg                 73MiB |
|    0   N/A  N/A      1937      G   /usr/bin/gnome-shell               13MiB |
|    0   N/A  N/A   2304356      G   gnome-control-center                3MiB |
+-----------------------------------------------------------------------------+
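The number that matters in this output is the `CUDA Version` field in the header, which is the highest CUDA version the driver supports (the toolkit that `nvcc` reports can be a different one). A small sketch for pulling it out of the text; the function name `driver_cuda_version` is hypothetical, and the sample string is the header line from my machine.

```python
import re


def driver_cuda_version(smi_text):
    """Extract (major, minor) from the 'CUDA Version: X.Y' field of nvidia-smi output."""
    match = re.search(r"CUDA Version:\s*(\d+)\.(\d+)", smi_text)
    if match is None:
        return None  # field not found, e.g. no driver or unexpected output
    return int(match.group(1)), int(match.group(2))


# Header line copied from the nvidia-smi output above.
header = "| NVIDIA-SMI 465.19.01 Driver Version: 465.19.01 CUDA Version: 11.3 |"
print(driver_cuda_version(header))
```

On a real machine the text could come from `subprocess.run(["nvidia-smi"], capture_output=True, text=True).stdout` rather than a pasted string.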
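- After digging through the wheel (`.whl`) files on the Google storage site for quite a while, I found that there are also wheel files tagged with the cuDNN version. Installing a jax that is too old is not an option either, because it conflicts with other libraries.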
- Check your own cuDNN version in the file `/usr/local/cuda-11.3/include/cudnn_version.h`; in my case it is `cudnn82`:
> #define CUDNN_MAJOR 8
> #define CUDNN_MINOR 2
- Then match your own Python version as well (`python=3.8` here).
(base) xxxx:~$ cat /usr/local/cuda-11.3/include/cudnn_version.h
/*
* Copyright 2019 NVIDIA Corporation. All rights reserved.
*
* NOTICE TO LICENSEE:
*
* This source code and/or documentation ("Licensed Deliverables") are
* subject to NVIDIA intellectual property rights under U.S. and
* international Copyright laws.
*
* These Licensed Deliverables contained herein is PROPRIETARY and
* CONFIDENTIAL to NVIDIA and is being provided under the terms and
* conditions of a form of NVIDIA software license agreement by and
* between NVIDIA and Licensee ("License Agreement") or electronically
* accepted by Licensee. Notwithstanding any terms or conditions to
* the contrary in the License Agreement, reproduction or disclosure
* of the Licensed Deliverables to any third party without the express
* written consent of NVIDIA is prohibited.
*
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
* LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
* SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS
* PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
* NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
* DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
* NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
* LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
* SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
* DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
* WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
* ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
* OF THESE LICENSED DELIVERABLES.
*
* U.S. Government End Users. These Licensed Deliverables are a
* "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
* 1995), consisting of "commercial computer software" and "commercial
* computer software documentation" as such terms are used in 48
* C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
* only as a commercial end item. Consistent with 48 C.F.R.12.212 and
* 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
* U.S. Government End Users acquire the Licensed Deliverables with
* only those rights set forth herein.
*
* Any use of the Licensed Deliverables in individual and commercial
* software must include, in the user documentation and internal
* comments to the code, the above Disclaimer and U.S. Government End
* Users Notice.
*/
/**
* \file: The master cuDNN version file.
*/
#ifndef CUDNN_VERSION_H_
#define CUDNN_VERSION_H_
#define CUDNN_MAJOR 8
#define CUDNN_MINOR 2
#define CUDNN_PATCHLEVEL 0
#define CUDNN_VERSION (CUDNN_MAJOR * 1000 + CUDNN_MINOR * 100 + CUDNN_PATCHLEVEL)
#endif /* CUDNN_VERSION_H */
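Reading the two defines by eye works, but the step can also be sketched in code. The snippet below builds the `cudnn82` style tag and the `cp38` style Python tag. The tag formats are inferred from the wheel file name used in this post and are an assumption for other releases; `cudnn_tag` is a hypothetical helper name.

```python
import re
import sys


def cudnn_tag(header_text):
    """Build a tag such as 'cudnn82' from the CUDNN_MAJOR / CUDNN_MINOR defines."""
    values = dict(
        re.findall(r"#define\s+(CUDNN_(?:MAJOR|MINOR))\s+(\d+)", header_text)
    )
    return f"cudnn{values['CUDNN_MAJOR']}{values['CUDNN_MINOR']}"


# The relevant lines of cudnn_version.h shown above.
sample = """
#define CUDNN_MAJOR 8
#define CUDNN_MINOR 2
#define CUDNN_PATCHLEVEL 0
"""

# Tag for the interpreter running this script, e.g. 'cp38' on Python 3.8.
python_tag = f"cp{sys.version_info.major}{sys.version_info.minor}"
print(cudnn_tag(sample), python_tag)
```

To use the real file, pass `open("/usr/local/cuda-11.3/include/cudnn_version.h").read()` instead of `sample`; the path will differ for other CUDA installs.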
Run `pip uninstall jax jaxlib` first, then install the matching version. (Remember: always uninstall first! Uninstall first, then install!)
pip install --upgrade jax==0.3.14 jaxlib==0.3.14+cuda11.cudnn82 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Looking in links: https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Collecting jax==0.3.14
Using cached jax-0.3.14-py3-none-any.whl
Collecting jaxlib==0.3.14+cuda11.cudnn82
Downloading https://storage.googleapis.com/jax-releases/cuda11/jaxlib-0.3.14%2Bcuda11.cudnn82-cp38-none-manylinux2014_x86_64.whl (161.9 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 161.9/161.9 MB 2.6 MB/s eta 0:00:00
Requirement already satisfied: absl-py in ./anaconda3/lib/python3.8/site-packages (from jax==0.3.14) (1.1.0)
Requirement already satisfied: numpy>=1.19 in ./anaconda3/lib/python3.8/site-packages (from jax==0.3.14) (1.23.1)
Requirement already satisfied: scipy>=1.5 in ./anaconda3/lib/python3.8/site-packages (from jax==0.3.14) (1.5.2)
Requirement already satisfied: typing-extensions in ./anaconda3/lib/python3.8/site-packages (from jax==0.3.14) (4.3.0)
Requirement already satisfied: opt-einsum in ./anaconda3/lib/python3.8/site-packages (from jax==0.3.14) (3.3.0)
Requirement already satisfied: etils[epath] in ./anaconda3/lib/python3.8/site-packages (from jax==0.3.14) (0.6.0)
Requirement already satisfied: flatbuffers<3.0,>=1.12 in ./anaconda3/lib/python3.8/site-packages (from jaxlib==0.3.14+cuda11.cudnn82) (1.12)
Requirement already satisfied: zipp in ./anaconda3/lib/python3.8/site-packages (from etils[epath]->jax==0.3.14) (3.4.0)
Requirement already satisfied: importlib_resources in ./anaconda3/lib/python3.8/site-packages (from etils[epath]->jax==0.3.14) (5.1.2)
Installing collected packages: jaxlib, jax
Successfully installed jax-0.3.14 jaxlib-0.3.14+cuda11.cudnn82
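To confirm that the CUDA build is really the one that ended up installed, the version string of `jaxlib` can be checked for the `+cuda...` suffix seen in the log above. This is a minimal sketch assuming the wheel naming of the 0.3.x releases used here (newer releases may package CUDA support differently); `has_cuda_build` is a hypothetical helper.

```python
from importlib import metadata


def has_cuda_build(version_string):
    """True if a jaxlib version string carries a '+cuda...' local label."""
    _, _, local = version_string.partition("+")
    return local.startswith("cuda")


try:
    installed = metadata.version("jaxlib")
except metadata.PackageNotFoundError:
    installed = None  # jaxlib is not installed in this environment

print("installed jaxlib:", installed)
print(has_cuda_build("0.3.14+cuda11.cudnn82"))  # the build installed above
print(has_cuda_build("0.3.14"))                 # a version string without the suffix
```

If the installed version has no `+cuda` suffix, pip most likely kept or picked a CPU-only wheel, which is exactly why uninstalling first matters.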