有几个原因:
- pytorch版本与cuda版本不一致
- 没有安装cudatoolkit
检查cuda版本
查看显卡版本:
nvidia-smi
可见我们显卡驱动是cuda11.6
可以看下本地有没安装错误版本的pytorch
conda list pytorch
重装pytorch
到pytorch官网查看对应你cuda版本的pytorch安装命令,执行安装。其中-c xxx
的意思是channel,应该是选择镜像源,如果下载不顺利,或者想从本地配置的源下载,可以把这两条-c
命令删除。
笔者当时用pip和conda各自删除一遍pytorch,再重新安装就好了。
运行代码查看是否安装成功
可以在python脚本中运行
print(torch.__version__)
print(torch.cuda.is_available())
得到输出如下时,就安装成功了
1.12.1
True
如果类似于如下输出,就没安装成功,需要排查原因。
1.12.1+cpu
False