Running the Qwen2-1.5B model on a Jetson GPU fails with RuntimeError: "triu_tril_cuda_template" not implemented for 'BFloat16'
Problem description:
from transformers import AutoModelForCausalLM, AutoTokenizer

CUDA_DEVICE = "cuda:0"
model_name_or_path = '/qwen2-1.5b-instruct'
Tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=False)
Model = AutoModelForCausalLM.from_pretrained(model_name_or_path, device_map=CUDA_DEVICE, torch_dtype="auto")
print("-------------------Model.device:", Model.device)

prompt = "hi"
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": prompt}
]

# Build the chat-formatted prompt text, tokenize it, and move it to the GPU
text = Tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
model_inputs = Tokenizer([text], return_tensors="pt").to(CUDA_DEVICE)
generated_ids = Model.generate(model_inputs.input_ids, top_p=0.2, max_new_tokens=512)
# Strip the prompt tokens so only the newly generated tokens are decoded
generated_ids = [
    output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
]
response = Tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(response)
The script fails at Model.generate() with: RuntimeError: "triu_tril_cuda_template" not implemented for 'BFloat16'.
Solution:
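The error most likely comes from the causal-mask construction inside transformers, which calls torch.triu on a bfloat16 tensor; this Jetson torch 2.0 build has no bf16 CUDA kernel for triu/tril. A minimal sketch that should reproduce the same error directly, assuming that root cause:

import torch

# Assumption: generate() fails because torch.triu is called on a bfloat16 CUDA tensor,
# and this Jetson torch 2.0 build does not implement that kernel.
mask = torch.full((4, 4), float("-inf"), dtype=torch.bfloat16, device="cuda:0")
torch.triu(mask, diagonal=1)  # expected to raise: "triu_tril_cuda_template" not implemented for 'BFloat16'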
After some digging, the cause turned out to be a mismatch between the transformers and torch versions; reinstalling transformers resolved it. My torch is the Jetson (NVIDIA JetPack) build.
Before (pip list excerpt):
torch 2.0.0a0+ec3941ad.nv23.2
torchaudio 0.13.1+b90d798
torchvision 0.14.1a0+5e8e2f1
tornado 6.2
tqdm 4.66.1
traitlets 5.9.0
transformers 4.44.2
Reinstall transformers==4.40.0:
pip install transformers==4.40.0 -i https://pypi.tuna.tsinghua.edu.cn/simple
After reinstalling (pip list excerpt):
torch 2.0.0a0+ec3941ad.nv23.2
torchaudio 0.13.1+b90d798
torchvision 0.14.1a0+5e8e2f1
tornado 6.2
tqdm 4.66.1
traitlets 5.9.0
transformers 4.40.0
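A quick sanity check after reinstalling, to confirm the interpreter actually picks up the expected versions (the version strings in the comments are simply the ones from my environment):

import torch
import transformers

print(torch.__version__)         # 2.0.0a0+ec3941ad.nv23.2 (Jetson build)
print(transformers.__version__)  # 4.40.0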
Thoughts: some posts online report that transformers 4.43.0 also works, so the right version depends on your own environment. The above is what solved it for me, offered for reference.
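Another workaround that is often suggested for this error (not the route I took, so treat it as an untested sketch): load the model in float16 instead of letting torch_dtype="auto" pick bfloat16 from the Qwen2 config, so the missing bf16 triu kernel is never needed:

import torch
from transformers import AutoModelForCausalLM

# Untested alternative: force fp16 weights so no bfloat16 triu/tril CUDA kernel is required.
Model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    device_map=CUDA_DEVICE,
    torch_dtype=torch.float16,  # instead of torch_dtype="auto"
)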
References
《Qwen2报错——RuntimeError: "triu_tril_cuda_template" not implemented for 'BFloat16'》 (CSDN blog post, in Chinese): https://blog.csdn.net/qq_35357274/article/details/141157962