PyTorch quantization error: backend mismatch

Environment: PyTorch 1.7.1
Problem: when running quantization-aware training (QAT) with the PyTorch quantization package, the final convert step fails:

Traceback (most recent call last):
  File "train.py", line 136, in <module>
    main()
  File "train.py", line 126, in main
    quantized_model = torch.quantization.convert(model.eval(), inplace=False)
  File "/data/yyl/anaconda3/envs/pytorch-1.7.1/lib/python3.7/site-packages/torch/quantization/quantize.py", line 414, in convert
    _convert(module, mapping, inplace=True)
  File "/data/yyl/anaconda3/envs/pytorch-1.7.1/lib/python3.7/site-packages/torch/quantization/quantize.py", line 458, in _convert
    _convert(mod, mapping, inplace=True)
  File "/data/yyl/anaconda3/envs/pytorch-1.7.1/lib/python3.7/site-packages/torch/quantization/quantize.py", line 459, in _convert
    reassign[name] = swap_module(mod, mapping)
  File "/data/yyl/anaconda3/envs/pytorch-1.7.1/lib/python3.7/site-packages/torch/quantization/quantize.py", line 485, in swap_module
    new_mod = mapping[type(mod)].from_float(mod)
  File "/data/yyl/anaconda3/envs/pytorch-1.7.1/lib/python3.7/site-packages/torch/nn/quantized/modules/conv.py", line 368, in from_float
    return cls.get_qconv(mod, activation_post_process, weight_post_process)
  File "/data/yyl/anaconda3/envs/pytorch-1.7.1/lib/python3.7/site-packages/torch/nn/quantized/modules/conv.py", line 157, in get_qconv
    qweight = _quantize_weight(mod.weight.float(), weight_post_process)
  File "/data/yyl/anaconda3/envs/pytorch-1.7.1/lib/python3.7/site-packages/torch/nn/quantized/modules/utils.py", line 16, in _quantize_weight
    wt_scale.to(torch.double), wt_zp.to(torch.int64), wt_axis, torch.qint8)
RuntimeError: Could not run 'aten::quantize_per_channel' with arguments from the 'CUDA' backend. 'aten::quantize_per_channel' is only available for these backends: [CPU, BackendSelect, Named, AutogradOther, AutogradCPU, AutogradCUDA, AutogradXLA, AutogradPrivateUse1, AutogradPrivateUse2, AutogradPrivateUse3, Tracer, Autocast, Batched, VmapMode].

CPU: registered at /opt/conda/conda-bld/pytorch_1607370141920/work/build/aten/src/ATen/CPUType.cpp:2127 [kernel]
BackendSelect: fallthrough registered at /opt/conda/conda-bld/pytorch_1607370141920/work/aten/src/ATen/core/BackendSelectFallbackKernel.cpp:3 [backend fallback]
Named: registered at /opt/conda/conda-bld/pytorch_1607370141920/work/aten/src/ATen/core/NamedRegistrations.cpp:7 [backend fallback]
AutogradOther: registered at /opt/conda/conda-bld/pytorch_1607370141920/work/torch/csrc/autograd/generated/VariableType_2.cpp:8078 [autograd kernel]
AutogradCPU: registered at /opt/conda/conda-bld/pytorch_1607370141920/work/torch/csrc/autograd/generated/VariableType_2.cpp:8078 [autograd kernel]
AutogradCUDA: registered at /opt/conda/conda-bld/pytorch_1607370141920/work/torch/csrc/autograd/generated/VariableType_2.cpp:8078 [autograd kernel]
AutogradXLA: registered at /opt/conda/conda-bld/pytorch_1607370141920/work/torch/csrc/autograd/generated/VariableType_2.cpp:8078 [autograd kernel]
AutogradPrivateUse1: registered at /opt/conda/conda-bld/pytorch_1607370141920/work/torch/csrc/autograd/generated/VariableType_2.cpp:8078 [autograd kernel]
AutogradPrivateUse2: registered at /opt/conda/conda-bld/pytorch_1607370141920/work/torch/csrc/autograd/generated/VariableType_2.cpp:8078 [autograd kernel]
AutogradPrivateUse3: registered at /opt/conda/conda-bld/pytorch_1607370141920/work/torch/csrc/autograd/generated/VariableType_2.cpp:8078 [autograd kernel]
Tracer: registered at /opt/conda/conda-bld/pytorch_1607370141920/work/torch/csrc/autograd/generated/TraceType_2.cpp:9654 [kernel]
Autocast: fallthrough registered at /opt/conda/conda-bld/pytorch_1607370141920/work/aten/src/ATen/autocast_mode.cpp:254 [backend fallback]
Batched: registered at /opt/conda/conda-bld/pytorch_1607370141920/work/aten/src/ATen/BatchingRegistrations.cpp:511 [backend fallback]
VmapMode: fallthrough registered at /opt/conda/conda-bld/pytorch_1607370141920/work/aten/src/ATen/VmapModeRegistrations.cpp:33 [backend fallback]

(pytorch-1.7.1) ➜  CIFAR-10 python train.py
Files already downloaded and verified
Files already downloaded and verified
/data/yyl/anaconda3/envs/pytorch-1.7.1/lib/python3.7/site-packages/torch/quantization/observer.py:121: UserWarning: Please use quant_min and quant_max to specify the range for observers. reduce_range will be deprecated in a future release of PyTorch.

Solution: the model was trained on CUDA, but the quantized kernels (such as aten::quantize_per_channel above) are only registered for the CPU backend. The model therefore has to be moved to the CPU before calling convert:

quantized_model = torch.quantization.convert(model.cpu().eval(), inplace=False)
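For context, here is a minimal, self-contained sketch of the whole QAT flow with the fix in place. The tiny Conv2d model and layer sizes are illustrative stand-ins for the real network; the training loop is reduced to a single forward pass so the observers collect statistics:

```python
import torch
import torch.nn as nn

# Toy QAT model: QuantStub/DeQuantStub mark the quantized region.
model = nn.Sequential(
    torch.quantization.QuantStub(),
    nn.Conv2d(3, 8, 3),
    nn.ReLU(),
    torch.quantization.DeQuantStub(),
)

# Attach a QAT qconfig and insert fake-quant modules (model must be in train mode).
model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
torch.quantization.prepare_qat(model, inplace=True)

# Stand-in for the training loop (on a GPU this is where model.cuda(),
# forward/backward, and optimizer steps would run).
model(torch.randn(2, 3, 8, 8))

# Key step: move the model back to CPU before convert, because the
# quantized kernels are only implemented for the CPU backend.
model_int8 = torch.quantization.convert(model.cpu().eval(), inplace=False)
```

Without the `.cpu()` call, `convert` tries to quantize CUDA-resident weight tensors and fails with the backend error shown above.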