Pytorch 环境
准备
查看 Nvidia 驱动版本
在安装Pytorch之前,需要确定NVIDIA驱动和CUDA版本
在命令行输入 nvidia-smi
查看环境 NVIDIA 驱动的版本
!注意:525.60.13 是驱动版本,CUDA Version:12.0 表示此时支持的最高 CUDA 版本
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.60.13 Driver Version: 525.60.13 CUDA Version: 12.0 |
|-------------------------------+----------------------+----------------------+
| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|===============================+======================+======================|
| 0 NVIDIA GeForce ... Off | 00000000:01:00.0 On | Off |
| 30% 54C P2 252W / 450W | 11470MiB / 24564MiB | 61% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
CUDA 和 NVIDIA 驱动版本的对照表
https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html
如果驱动版本太低,可以升级 NVIDIA 驱动
安装
Pytorch 官网: https://pytorch.org
在 Install Pytorch 部分,选择 pytorch Build、OS、Package、Language、Compute Platform
Compute Platform 请选择驱动支持的CUDA版本
对应的会在Run this Command 部分生成下载所需的命令行指令
这部分只支持最新版本的 pytorch 安装
如pytorch 2.0.1, Linux, pip, Python, CUDA11.8
Command: pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
如果需要安装旧版本,点击install previous versions of Pytorch,选择自己需要的版本
如 v1.13.0 版本,使用 pip 安装:
# ROCM 5.2 (Linux only)
pip install torch==1.13.0+rocm5.2 torchvision==0.14.0+rocm5.2 torchaudio==0.13.0 --extra-index-url https://download.pytorch.org/whl/rocm5.2
# CUDA 11.6
pip install torch==1.13.0+cu116 torchvision==0.14.0+cu116 torchaudio==0.13.0 --extra-index-url https://download.pytorch.org/whl/cu116
# CUDA 11.7
pip install torch==1.13.0+cu117 torchvision==0.14.0+cu117 torchaudio==0.13.0 --extra-index-url https://download.pytorch.org/whl/cu117
# CPU only
pip install torch==1.13.0+cpu torchvision==0.14.0+cpu torchaudio==0.13.0 --extra-index-url https://download.pytorch.org/whl/cpu
测试
在 Python Console 中,或者创建 .py 文件运行
import torch
print(torch.__version__) # 查看 torch 版本
# '2.0.0+cu118'
print(torch.version.cuda) # 查看 cuda 版本(这是 Pytorch 所安装和调用的 CUDA 版本,与电脑实际的 CUDA 版本不一样)
# '11.8'
print(torch.cuda.is_available()) # 查看cuda是否可用
# 'True'