我的虚拟环境:
torch==1.7.0+cu110
numpy==1.21.6
transformers==2.11.0
urllib3==1.26.16
安装流程:
1. 下载apex项目
地址:https://github.com/NVIDIA/apex
Git命令下载或安装包下载均可
2. 安装命令
python setup.py install
显示下载成功:
但我在使用的时候报错:AttributeError: module 'torch.distributed' has no attribute '_all_gather_base'
,然后在https://github.com/NVIDIA/apex/issues/1532里面看到可能是因为版本的原因,下载22.04-dev版本的apex项目或许可以。
下载地址:https://github.com/NVIDIA/apex/tree/22.04-dev?search=1
然后就试了一下,成功解决!!