PyTorch error: AssertionError: If capturable=False, state_steps should not be CUDA tensors.

  1. The full error output:

      File "/opt/anaconda3/envs/xxx/lib/python3.10/site-packages/torch/optim/lr_scheduler.py", line 65, in wrapper
        return wrapped(*args, **kwargs)
      File "/opt/anaconda3/envs/xxx/lib/python3.10/site-packages/torch/optim/optimizer.py", line 109, in wrapper
        return func(*args, **kwargs)
      File "/opt/anaconda3/envs/xxx/lib/python3.10/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
        return func(*args, **kwargs)
      File "/opt/anaconda3/envs/xxx/lib/python3.10/site-packages/torch/optim/adam.py", line 157, in step
        adam(params_with_grad,
      File "/opt/anaconda3/envs/xxx/lib/python3.10/site-packages/torch/optim/adam.py", line 213, in adam
        func(params,
      File "/opt/anaconda3/envs/xxx/lib/python3.10/site-packages/torch/optim/adam.py", line 255, in _single_tensor_adam
        assert not step_t.is_cuda, "If capturable=False, state_steps should not be CUDA tensors."
    AssertionError: If capturable=False, state_steps should not be CUDA tensors.
    ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 0 (pid: 4681) of binary: /opt/anaconda3/envs/xxx/bin/python
    Traceback (most recent call last):
      File "/opt/anaconda3/envs/xxx/lib/python3.10/runpy.py", line 196, in _run_module_as_main
        return _run_code(code, main_globals, None,
      File "/opt/anaconda3/envs/xxx/lib/python3.10/runpy.py", line 86, in _run_code
        exec(code, run_globals)
      File "/opt/anaconda3/envs/xxx/lib/python3.10/site-packages/torch/distributed/launch.py", line 193, in <module>
        main()
      File "/opt/anaconda3/envs/xxx/lib/python3.10/site-packages/torch/distributed/launch.py", line 189, in main
        launch(args)
      File "/opt/anaconda3/envs/xxx/lib/python3.10/site-packages/torch/distributed/launch.py", line 174, in launch
        run(args)
      File "/opt/anaconda3/envs/xxx/lib/python3.10/site-packages/torch/distributed/run.py", line 752, in run
        elastic_launch(
      File "/opt/anaconda3/envs/xxx/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 131, in __call__
        return launch_agent(self._config, self._entrypoint, list(args))
      File "/opt/anaconda3/envs/xxx/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 245, in launch_agent
        raise ChildFailedError(
    torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
    
  2. Likely cause
    The error is related to the `capturable` parameter newly introduced in torch 1.12.0. In that release, Adam stores each parameter's `step` counter as a tensor, and with `capturable=False` it asserts that these `state_steps` are not CUDA tensors; resuming from a checkpoint can leave them on the GPU, which triggers the assertion.
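
    To see where the offending state lives, one can inspect the optimizer state after a single step. This is a CPU-only sketch (the crash itself only reproduces with the state on a CUDA device); the names `model` and `optim` are illustrative, not from the post:

    ```python
    import torch

    model = torch.nn.Linear(4, 2)
    optim = torch.optim.Adam(model.parameters(), lr=1e-3)

    # One backward/step so the optimizer materializes per-parameter state.
    model(torch.randn(3, 4)).sum().backward()
    optim.step()

    # In torch 1.12.0 each parameter's 'step' counter is stored as a tensor;
    # the single-tensor Adam path asserts `not step_t.is_cuda` when
    # capturable=False, so a 'step' tensor on a CUDA device raises the error.
    for state in optim.state.values():
        print(type(state["step"]), getattr(state["step"], "device", None))
    ```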

  3. Environment package versions:

    PyTorch: 1.12.0+cu116
    TorchVision: 0.13.0+cu116
    
  4. Solutions

    • After loading the checkpoint, set the parameter capturable to True:
      optim.param_groups[0]['capturable'] = True
      
    • Alternatively, run the following command to install torch 1.12.1 and the matching torchvision:
      pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu116
      

    Note: the command above installs the CUDA 11.6 builds of torch and torchvision; other torch versions, such as 1.10 or 1.13, can be installed instead.
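
    The first workaround can be sketched end to end as follows. Note that the post flips the flag only on `param_groups[0]`; if the optimizer was built with several parameter groups, it should be set on every group. This is a minimal sketch assuming torch >= 1.12; `model`, `optim`, and `ckpt` are illustrative names:

    ```python
    import torch

    model = torch.nn.Linear(4, 2)
    optim = torch.optim.Adam(model.parameters(), lr=1e-3)

    # One step so the optimizer actually has per-parameter state to save.
    model(torch.randn(3, 4)).sum().backward()
    optim.step()

    # Stand-in for ckpt = torch.load(checkpoint_path).
    ckpt = {"optimizer": optim.state_dict()}
    optim.load_state_dict(ckpt["optimizer"])

    # Enable `capturable` on every parameter group, not just the first one.
    for group in optim.param_groups:
        group["capturable"] = True
    ```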


Ref

[1] https://github.com/pytorch/pytorch/issues/80809#issuecomment-1173481031
