PyTorch error: AssertionError: If capturable=False, state_steps should not be CUDA tensors.
-
The full error output is as follows:
  File "/opt/anaconda3/envs/xxx/lib/python3.10/site-packages/torch/optim/lr_scheduler.py", line 65, in wrapper
    return wrapped(*args, **kwargs)
  File "/opt/anaconda3/envs/xxx/lib/python3.10/site-packages/torch/optim/optimizer.py", line 109, in wrapper
    return func(*args, **kwargs)
  File "/opt/anaconda3/envs/xxx/lib/python3.10/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/opt/anaconda3/envs/xxx/lib/python3.10/site-packages/torch/optim/adam.py", line 157, in step
    adam(params_with_grad,
  File "/opt/anaconda3/envs/xxx/lib/python3.10/site-packages/torch/optim/adam.py", line 213, in adam
    func(params,
  File "/opt/anaconda3/envs/xxx/lib/python3.10/site-packages/torch/optim/adam.py", line 255, in _single_tensor_adam
    assert not step_t.is_cuda, "If capturable=False, state_steps should not be CUDA tensors."
AssertionError: If capturable=False, state_steps should not be CUDA tensors.
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 0 (pid: 4681) of binary: /opt/anaconda3/envs/xxx/bin/python
Traceback (most recent call last):
  File "/opt/anaconda3/envs/xxx/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/opt/anaconda3/envs/xxx/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/opt/anaconda3/envs/xxx/lib/python3.10/site-packages/torch/distributed/launch.py", line 193, in <module>
    main()
  File "/opt/anaconda3/envs/xxx/lib/python3.10/site-packages/torch/distributed/launch.py", line 189, in main
    launch(args)
  File "/opt/anaconda3/envs/xxx/lib/python3.10/site-packages/torch/distributed/launch.py", line 174, in launch
    run(args)
  File "/opt/anaconda3/envs/xxx/lib/python3.10/site-packages/torch/distributed/run.py", line 752, in run
    elastic_launch(
  File "/opt/anaconda3/envs/xxx/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 131, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/opt/anaconda3/envs/xxx/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 245, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
-
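Possible cause
The error is related to the capturable parameter newly introduced in torch 1.12.0. The traceback shows the assertion firing inside Adam's step because the optimizer's step state is a CUDA tensor while capturable is False; according to the issue in [1], this happens after the optimizer state is restored from a checkpoint.
-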
Package versions in the runtime environment:
PyTorch: 1.12.0+cu116 TorchVision: 0.13.0+cu116
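To check whether a given environment matches the one above, the version string can be compared against 1.12.0. The following is a minimal sketch, assuming that only the 1.12.0 release is affected (which is what the referenced issue and the upgrade-based fix suggest); the helper name is_affected_version is hypothetical and not part of PyTorch.

```python
def is_affected_version(version: str) -> bool:
    """Return True if the version string denotes torch 1.12.0.

    Hypothetical helper: assumes only the 1.12.0 release triggers the error.
    """
    # Drop the local build suffix, e.g. "1.12.0+cu116" -> "1.12.0"
    release = version.split("+")[0]
    return release == "1.12.0"


# Version strings taken from this article
print(is_affected_version("1.12.0+cu116"))  # environment that raised the error
print(is_affected_version("1.12.1+cu116"))  # version installed by the fix
```

In a real script, torch.__version__ would be passed instead of a literal string.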
-
Solutions
- After loading the checkpoint, set the capturable parameter to True: optim.param_groups[0]['capturable'] = True
- Run the following command to install torch 1.12.1 and the matching torchvision: pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu116
Note: the command above installs the torch and torchvision builds for CUDA 11.6. Other torch versions, such as 1.10 or 1.13, can be installed instead.
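The first workaround above only touches param_groups[0]; an optimizer with several parameter groups would need the flag set on each one. The following is a minimal sketch of that idea, not a definitive fix: the helper name enable_capturable, the toy Linear model, and the in-memory checkpoint dict are all assumptions made for illustration.

```python
import torch


def enable_capturable(optimizer: torch.optim.Optimizer) -> None:
    """Set capturable=True on every param group (hypothetical helper)."""
    for group in optimizer.param_groups:
        group["capturable"] = True


# Toy model and optimizer standing in for the real training setup
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Simulate saving and restoring a checkpoint in memory
checkpoint = {"model": model.state_dict(), "optimizer": optimizer.state_dict()}
model.load_state_dict(checkpoint["model"])
optimizer.load_state_dict(checkpoint["optimizer"])

# Apply the workaround after load_state_dict, which replaces the param groups
enable_capturable(optimizer)
```

The flag is set after load_state_dict because loading overwrites the optimizer's param groups with the saved ones. This sketch does not call optimizer.step(); the workaround is meant for training on CUDA, where the step state lives on the GPU.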
Ref
[1] https://github.com/pytorch/pytorch/issues/80809#issuecomment-1173481031