在对 CLIP 模型的文本 encoder 模块从 .pt 到 .onnx 的转化过程中,遇到报错
Error loading model with onnxruntime: [ONNXRuntimeError] : 9 : NOT_IMPLEMENTED : Could not find an implementation for ArgMax(13) node with name '/ArgMax'
遇到的问题大概率是由于ONNX Runtime不支持模型中的ArgMax操作的 int64
类型输入而导致的。解决方法是将ArgMax的输入类型更改为 float32 或 int32
。
可以按照以下步骤进行操作:
- 找到模型中的ArgMax操作的定义。
- 将ArgMax操作的输入类型更改为float32或int32。
- 保存修改后的模型。
- 使用修改后的模型进行推理。
另外,也可以尝试在模型外部实现ArgMax操作,并将结果传递给网络。
在进行任何修改之前,建议先备份原始模型,以防意外情况发生。
# 原代码
x_txt = torch.randint(0, 49408, (79, 77))
# 修改后代码
x_txt = torch.randint(0, 49408, (79, 77)).to(torch.int32)
这里的范围 0 到 49408 的 token 输入通常来自于 CLIP 模型的文本 tokenization 过程。CLIP 模型中的文本输入需要经过 tokenization,以便将每个词转换为一个唯一的 token ID。在上述的例子中:
- 文本长度 (77): 这个长度通常是 CLIP 模型的最大上下文长度,即一个句子中的最大 token 数量。在 partial(clip.tokenize, context_length=77) 中,context_length 被设置为 77。
- batch size (79): 这个值是输入 batch 中文本句子的数量。
- token 范围 (0 到 49408): 这是 CLIP 模型的词汇表大小。CLIP 的词汇表大小通常为 49408。每个 token ID 的范围从 0 到 49408,表示不同的词或标记。
参考方案:
https://github.com/microsoft/onnxruntime/issues/9760