pytorch,cuda,cudatoolkit,driver四者关系如下:
- 系统的Nvidia Driver决定着系统最高可以支持什么版本的cuda和cudatoolkit,Nvidia Driver是向下兼容的,详情如下(见Nvidia Driver和Cuda对应关系)
- cuda和cudatoolkit不同,前者说的是系统安装的cuda,它是由Nvidia官方提供的(/usr/local/cuda就是系统安装的cuda的软链接),这与我们要安装的pytorch几乎没有什么关系。后者是anaconda官方提供的用来build pytorch的一个工具包,它是Nvidia所提供的cuda的一个子集。
- pytorch和cudatoolkit版本并不是一一对应的关系,一个pytorch版本可以有多个cudatoolkit版本与之对应。例如1.5.1版本的pytorch,既可以使用9.2版本的cudatoolkit,也可以使用10.2版本的cudatoolkit。
pytorch安装的坑
只指定pytorch版本来安装不一定是能work的,例如执行conda install pytorch=X.X.X -c pytorch
时,conda会自动为你选择合适版本的cudatoolkit。但conda只能保证你的pytorch和cudatoolkit版本一定是对应的,但并不能保证pytorch可以正常使用,因为系统的Nvidia Driver有可能不支持你所安装的cudatoolkit版本。
所以,除非你对你的Nvidia driver版本很有自信,否则,还是先查看系统Nvidia Driver的版本,并在上方图表中查询最高支持的cudatoolkit版本,然后指定cudatoolkit版本来安装pytorch吧。例如系统的Nvidia Driver版本为440.33,查询到最高支持cudatoolkit版本为10.2,则可以使用conda install pytorch cudatoolkit=10.2 -c pytorch
命令安装pytorch。
当然,如果你对pytorch版本有特别的要求,你可以同时指定pytorch和cudatoolkit的版本。如果这两个版本不能兼容,系统会报错,例如:
conda install pytorch=1.5.1 cudatoolkit=9.0 -c pytorch
conda会告诉你:
Solving environment: failed
UnsatisfiableError: The following specifications were found to be in conflict:
- cudatoolkit=9.0 -> __cuda[version='>=9.0']
- pytorch=1.5.1
Use "conda info <package>" to see the dependencies for each package.
如果你的Nvidia Driver最高支持的cudatoolkit版本为9.0,而你又一定要用1.5.1版本的pytorch,那么你必须升级系统的Nvidia Driver版本。