[Solved] Converting the CLIP text encoder from .pt to ONNX produces a model with no input node: [ONNXRuntimeError] : 2 : INVALID_ARGUMENT : Invalid

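If you run into this kind of problem, the following two checks confirm whether it is the same issue described here:

  1. Inspect the model's declared inputs, (1) via onnx_model.graph.input and (2) in Netron, the ONNX visualization tool. Neither shows any input.
  2. Load the model with onnxruntime and run inference. It fails with [ONNXRuntimeError] : 2 : INVALID_ARGUMENT : Invalid

Solution: pass a randomly generated tensor as the input to export rather than a concrete value such as the text itself. A concrete value is treated as a constant and folded into the ONNX graph, leaving the model with no input node.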

1. Inspect the model's declared inputs

(1) onnx's .graph.input lists the model inputs defined at export time. Example code:

import onnx
import onnxruntime as ort

# Load the ONNX model
onnx_model_path = "/path/to/clip_txt1.onnx"
onnx_model = onnx.load(onnx_model_path)

# Print the input definitions stored in the ONNX graph
print("ONNX model input definitions:")
for inp in onnx_model.graph.input:
    print(f"Name: {inp.name}")
    print(f"Type: {inp.type}")
    # dim_param holds the symbolic name of a dynamic axis, dim_value a fixed size
    print(f"Shape: {[dim.dim_param or dim.dim_value for dim in inp.type.tensor_type.shape.dim]}")

# Load the model with onnxruntime and print the inputs it reports
try:
    ort_session = ort.InferenceSession(onnx_model_path)
    print("\nONNX Runtime model input definitions:")
    for inp in ort_session.get_inputs():
        print(f"Name: {inp.name}")
        print(f"Type: {inp.type}")
        print(f"Shape: {inp.shape}")
except Exception as e:
    print(f"Error loading model with onnxruntime: {e}")

# Run the ONNX checker on the model
try:
    onnx.checker.check_model(onnx_model)
    print("The model is valid.")
except onnx.checker.ValidationError as e:
    print(f"The model is invalid: {e}")

If the model has no inputs, the output may look like this:

ONNX model input definitions:
2024-05-23 09:51:04.537391508 [W:onnxruntime:, constant_folding.cc:269 ApplyImpl] Could not find a CPU kernel and hence can't constant fold Add node '/Add_79'
2024-05-23 09:51:04.715214096 [W:onnxruntime:, constant_folding.cc:269 ApplyImpl] Could not find a CPU kernel and hence can't constant fold Sqrt node '/transformer/resblocks/resblocks.0/attn/Sqrt'
2024-05-23 09:51:04.715499051 [W:onnxruntime:, constant_folding.cc:269 ApplyImpl] Could not find a CPU kernel and hence can't constant fold Sqrt node '/transformer/resblocks/resblocks.1/attn/Sqrt'
2024-05-23 09:51:04.715710237 [W:onnxruntime:, constant_folding.cc:269 ApplyImpl] Could not find a CPU kernel and hence can't constant fold Sqrt node '/transformer/resblocks/resblocks.2/attn/Sqrt'
2024-05-23 09:51:04.715895645 [W:onnxruntime:, constant_folding.cc:269 ApplyImpl] Could not find a CPU kernel and hence can't constant fold Sqrt node '/transformer/resblocks/resblocks.3/attn/Sqrt'
2024-05-23 09:51:04.716076274 [W:onnxruntime:, constant_folding.cc:269 ApplyImpl] Could not find a CPU kernel and hence can't constant fold Sqrt node '/transformer/resblocks/resblocks.4/attn/Sqrt'
2024-05-23 09:51:04.716256863 [W:onnxruntime:, constant_folding.cc:269 ApplyImpl] Could not find a CPU kernel and hence can't constant fold Sqrt node '/transformer/resblocks/resblocks.5/attn/Sqrt'
2024-05-23 09:51:04.716439756 [W:onnxruntime:, constant_folding.cc:269 ApplyImpl] Could not find a CPU kernel and hence can't constant fold Sqrt node '/transformer/resblocks/resblocks.6/attn/Sqrt'
2024-05-23 09:51:04.716620936 [W:onnxruntime:, constant_folding.cc:269 ApplyImpl] Could not find a CPU kernel and hence can't constant fold Sqrt node '/transformer/resblocks/resblocks.7/attn/Sqrt'
2024-05-23 09:51:04.716800192 [W:onnxruntime:, constant_folding.cc:269 ApplyImpl] Could not find a CPU kernel and hence can't constant fold Sqrt node '/transformer/resblocks/resblocks.8/attn/Sqrt'
2024-05-23 09:51:04.716978226 [W:onnxruntime:, constant_folding.cc:269 ApplyImpl] Could not find a CPU kernel and hence can't constant fold Sqrt node '/transformer/resblocks/resblocks.9/attn/Sqrt'
2024-05-23 09:51:04.717158474 [W:onnxruntime:, constant_folding.cc:269 ApplyImpl] Could not find a CPU kernel and hence can't constant fold Sqrt node '/transformer/resblocks/resblocks.10/attn/Sqrt'
2024-05-23 09:51:04.717340426 [W:onnxruntime:, constant_folding.cc:269 ApplyImpl] Could not find a CPU kernel and hence can't constant fold Sqrt node '/transformer/resblocks/resblocks.11/attn/Sqrt'
2024-05-23 09:51:04.737839929 [W:onnxruntime:, constant_folding.cc:269 ApplyImpl] Could not find a CPU kernel and hence can't constant fold Sqrt node '/transformer/resblocks/resblocks.11/attn/Sqrt'
2024-05-23 09:51:04.737862642 [W:onnxruntime:, constant_folding.cc:269 ApplyImpl] Could not find a CPU kernel and hence can't constant fold Sqrt node '/transformer/resblocks/resblocks.10/attn/Sqrt'
2024-05-23 09:51:04.737872220 [W:onnxruntime:, constant_folding.cc:269 ApplyImpl] Could not find a CPU kernel and hence can't constant fold Sqrt node '/transformer/resblocks/resblocks.9/attn/Sqrt'
2024-05-23 09:51:04.737881878 [W:onnxruntime:, constant_folding.cc:269 ApplyImpl] Could not find a CPU kernel and hence can't constant fold Sqrt node '/transformer/resblocks/resblocks.8/attn/Sqrt'
2024-05-23 09:51:04.737890714 [W:onnxruntime:, constant_folding.cc:269 ApplyImpl] Could not find a CPU kernel and hence can't constant fold Sqrt node '/transformer/resblocks/resblocks.7/attn/Sqrt'
2024-05-23 09:51:04.737899811 [W:onnxruntime:, constant_folding.cc:269 ApplyImpl] Could not find a CPU kernel and hence can't constant fold Sqrt node '/transformer/resblocks/resblocks.6/attn/Sqrt'
2024-05-23 09:51:04.737908648 [W:onnxruntime:, constant_folding.cc:269 ApplyImpl] Could not find a CPU kernel and hence can't constant fold Sqrt node '/transformer/resblocks/resblocks.5/attn/Sqrt'
2024-05-23 09:51:04.737917344 [W:onnxruntime:, constant_folding.cc:269 ApplyImpl] Could not find a CPU kernel and hence can't constant fold Sqrt node '/transformer/resblocks/resblocks.4/attn/Sqrt'
2024-05-23 09:51:04.737926081 [W:onnxruntime:, constant_folding.cc:269 ApplyImpl] Could not find a CPU kernel and hence can't constant fold Sqrt node '/transformer/resblocks/resblocks.3/attn/Sqrt'
2024-05-23 09:51:04.737934697 [W:onnxruntime:, constant_folding.cc:269 ApplyImpl] Could not find a CPU kernel and hence can't constant fold Sqrt node '/transformer/resblocks/resblocks.2/attn/Sqrt'
2024-05-23 09:51:04.737943253 [W:onnxruntime:, constant_folding.cc:269 ApplyImpl] Could not find a CPU kernel and hence can't constant fold Sqrt node '/transformer/resblocks/resblocks.1/attn/Sqrt'
2024-05-23 09:51:04.737951839 [W:onnxruntime:, constant_folding.cc:269 ApplyImpl] Could not find a CPU kernel and hence can't constant fold Sqrt node '/transformer/resblocks/resblocks.0/attn/Sqrt'
2024-05-23 09:51:04.738515476 [W:onnxruntime:, constant_folding.cc:269 ApplyImpl] Could not find a CPU kernel and hence can't constant fold Add node '/Add_79'

ONNX Runtime model input definitions:
The model is valid.

Even though the model passes onnx.checker.check_model, it is still broken: no input exists.

(2) In Netron, the ONNX visualization tool, the topmost node is an Add rather than an Input.
Alternatively, Model Properties shows no Inputs section (only Outputs).
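The "does this model have any input?" check can be folded into a small helper. This is a minimal sketch and not part of the original post: `real_input_names` and `has_real_inputs` are hypothetical names, and the only assumption is that the object passed in exposes `.input` and `.initializer` sequences whose items have a `.name`, as `onnx_model.graph` does. Initializer names are filtered out because some exports also list weights under `graph.input`.

```python
def real_input_names(graph):
    """Return the names in graph.input that are not backed by an initializer."""
    initializer_names = {init.name for init in graph.initializer}
    return [inp.name for inp in graph.input if inp.name not in initializer_names]


def has_real_inputs(graph):
    """True if the graph has at least one input that must be fed at run time."""
    return len(real_input_names(graph)) > 0
```

With the model loaded as `onnx_model = onnx.load(onnx_model_path)`, calling `has_real_inputs(onnx_model.graph)` should return False for the broken export described here, which makes it easy to fail fast right after conversion.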

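2. Load the model with onnxruntime and run inference

Example code for running the model with onnxruntime is shown below; the official examples are also a useful reference.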

import onnxruntime
from PIL import Image
from typing import AnyStr
import numpy as np
import torch
import torch.nn as nn
from typing import Union, List, Tuple
from functools import partial
from torchvision import transforms
import os
import sys
from pytorch_svgrender.painter.clipfont import (imagenet_templates, compose_text_with_templates, Painter,
                                                PainterOptimizer)
import clip

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# First input and its preprocessing
prompt = 'Green Cthulhu'
template_text = compose_text_with_templates(prompt, imagenet_templates)
tokenize_fn = partial(clip.tokenize, context_length=77)
tokens1 = tokenize_fn(template_text).numpy()

# Create the runtime session
ort_session = onnxruntime.InferenceSession('/path/to/clip_txt1.onnx', providers=["CUDAExecutionProvider"])
# onnxruntime.InferenceSession creates an ONNX Runtime inference session; its argument is the ONNX model file to run.

# First runtime input
ort_inputs = {'input': tokens1}

# First runtime inference
text_features = ort_session.run(['output'], ort_inputs)[0]

# Second input and its preprocessing
source = "A photo"
template_source = compose_text_with_templates(source, imagenet_templates)
tokens2 = tokenize_fn(template_source).numpy()  # tokenize the source templates

# Second runtime input
ort_inputs = {'input': tokens2}

# Second runtime inference
text_source = ort_session.run(['output'], ort_inputs)[0]

This fails with the following error:

/path/miniconda3/envs/svgrender/lib/python3.10/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py:69: UserWarning: Specified provider 'CUDAExecutionProvider' is not in available provider names.Available providers: 'AzureExecutionProvider, CPUExecutionProvider'
  warnings.warn(
2024-05-23 10:03:56.768496802 [W:onnxruntime:, constant_folding.cc:269 ApplyImpl] Could not find a CPU kernel and hence can't constant fold Add node '/Add_79'
2024-05-23 10:03:56.943615158 [W:onnxruntime:, constant_folding.cc:269 ApplyImpl] Could not find a CPU kernel and hence can't constant fold Sqrt node '/transformer/resblocks/resblocks.0/attn/Sqrt'
2024-05-23 10:03:56.943943475 [W:onnxruntime:, constant_folding.cc:269 ApplyImpl] Could not find a CPU kernel and hence can't constant fold Sqrt node '/transformer/resblocks/resblocks.1/attn/Sqrt'
2024-05-23 10:03:56.944162473 [W:onnxruntime:, constant_folding.cc:269 ApplyImpl] Could not find a CPU kernel and hence can't constant fold Sqrt node '/transformer/resblocks/resblocks.2/attn/Sqrt'
2024-05-23 10:03:56.944398975 [W:onnxruntime:, constant_folding.cc:269 ApplyImpl] Could not find a CPU kernel and hence can't constant fold Sqrt node '/transformer/resblocks/resblocks.3/attn/Sqrt'
2024-05-23 10:03:56.944606811 [W:onnxruntime:, constant_folding.cc:269 ApplyImpl] Could not find a CPU kernel and hence can't constant fold Sqrt node '/transformer/resblocks/resblocks.4/attn/Sqrt'
2024-05-23 10:03:56.944810019 [W:onnxruntime:, constant_folding.cc:269 ApplyImpl] Could not find a CPU kernel and hence can't constant fold Sqrt node '/transformer/resblocks/resblocks.5/attn/Sqrt'
2024-05-23 10:03:56.945021934 [W:onnxruntime:, constant_folding.cc:269 ApplyImpl] Could not find a CPU kernel and hence can't constant fold Sqrt node '/transformer/resblocks/resblocks.6/attn/Sqrt'
2024-05-23 10:03:56.945225472 [W:onnxruntime:, constant_folding.cc:269 ApplyImpl] Could not find a CPU kernel and hence can't constant fold Sqrt node '/transformer/resblocks/resblocks.7/attn/Sqrt'
2024-05-23 10:03:56.945441505 [W:onnxruntime:, constant_folding.cc:269 ApplyImpl] Could not find a CPU kernel and hence can't constant fold Sqrt node '/transformer/resblocks/resblocks.8/attn/Sqrt'
2024-05-23 10:03:56.945642549 [W:onnxruntime:, constant_folding.cc:269 ApplyImpl] Could not find a CPU kernel and hence can't constant fold Sqrt node '/transformer/resblocks/resblocks.9/attn/Sqrt'
2024-05-23 10:03:56.945846608 [W:onnxruntime:, constant_folding.cc:269 ApplyImpl] Could not find a CPU kernel and hence can't constant fold Sqrt node '/transformer/resblocks/resblocks.10/attn/Sqrt'
2024-05-23 10:03:56.946049756 [W:onnxruntime:, constant_folding.cc:269 ApplyImpl] Could not find a CPU kernel and hence can't constant fold Sqrt node '/transformer/resblocks/resblocks.11/attn/Sqrt'
2024-05-23 10:03:56.962879497 [W:onnxruntime:, constant_folding.cc:269 ApplyImpl] Could not find a CPU kernel and hence can't constant fold Sqrt node '/transformer/resblocks/resblocks.11/attn/Sqrt'
2024-05-23 10:03:56.962898172 [W:onnxruntime:, constant_folding.cc:269 ApplyImpl] Could not find a CPU kernel and hence can't constant fold Sqrt node '/transformer/resblocks/resblocks.10/attn/Sqrt'
2024-05-23 10:03:56.962914894 [W:onnxruntime:, constant_folding.cc:269 ApplyImpl] Could not find a CPU kernel and hence can't constant fold Sqrt node '/transformer/resblocks/resblocks.9/attn/Sqrt'
2024-05-23 10:03:56.962932037 [W:onnxruntime:, constant_folding.cc:269 ApplyImpl] Could not find a CPU kernel and hence can't constant fold Sqrt node '/transformer/resblocks/resblocks.8/attn/Sqrt'
2024-05-23 10:03:56.962950051 [W:onnxruntime:, constant_folding.cc:269 ApplyImpl] Could not find a CPU kernel and hence can't constant fold Sqrt node '/transformer/resblocks/resblocks.7/attn/Sqrt'
2024-05-23 10:03:56.962967385 [W:onnxruntime:, constant_folding.cc:269 ApplyImpl] Could not find a CPU kernel and hence can't constant fold Sqrt node '/transformer/resblocks/resblocks.6/attn/Sqrt'
2024-05-23 10:03:56.962985429 [W:onnxruntime:, constant_folding.cc:269 ApplyImpl] Could not find a CPU kernel and hence can't constant fold Sqrt node '/transformer/resblocks/resblocks.5/attn/Sqrt'
2024-05-23 10:03:56.963002401 [W:onnxruntime:, constant_folding.cc:269 ApplyImpl] Could not find a CPU kernel and hence can't constant fold Sqrt node '/transformer/resblocks/resblocks.4/attn/Sqrt'
2024-05-23 10:03:56.963019965 [W:onnxruntime:, constant_folding.cc:269 ApplyImpl] Could not find a CPU kernel and hence can't constant fold Sqrt node '/transformer/resblocks/resblocks.3/attn/Sqrt'
2024-05-23 10:03:56.963037498 [W:onnxruntime:, constant_folding.cc:269 ApplyImpl] Could not find a CPU kernel and hence can't constant fold Sqrt node '/transformer/resblocks/resblocks.2/attn/Sqrt'
2024-05-23 10:03:56.963054170 [W:onnxruntime:, constant_folding.cc:269 ApplyImpl] Could not find a CPU kernel and hence can't constant fold Sqrt node '/transformer/resblocks/resblocks.1/attn/Sqrt'
2024-05-23 10:03:56.963072215 [W:onnxruntime:, constant_folding.cc:269 ApplyImpl] Could not find a CPU kernel and hence can't constant fold Sqrt node '/transformer/resblocks/resblocks.0/attn/Sqrt'
2024-05-23 10:03:56.963655278 [W:onnxruntime:, constant_folding.cc:269 ApplyImpl] Could not find a CPU kernel and hence can't constant fold Add node '/Add_79'
Traceback (most recent call last):
  File "/path/PyTorch-SVGRender/test_onnx_txt.py", line 32, in <module>
    text_features = ort_session.run(['output'], ort_inputs)[0]
  File "/path/lib/python3.10/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 220, in run
    return self._sess.run(output_names, input_feed, run_options)
onnxruntime.capi.onnxruntime_pybind11_state.InvalidArgument: [ONNXRuntimeError] : 2 : INVALID_ARGUMENT : Invalid input name: input
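The "Invalid input name" message does not say that the model has no inputs at all. A small guard run before `session.run` can make that explicit. The following is a sketch under stated assumptions: `check_feed_names` is a hypothetical helper, not an onnxruntime API, and it relies only on `session.get_inputs()` returning objects with a `.name` attribute, as onnxruntime's InferenceSession does.

```python
def check_feed_names(session, feed):
    """Raise a readable error if the feed keys do not match the model's inputs."""
    expected = [node.name for node in session.get_inputs()]
    if not expected:
        raise ValueError(
            "The model declares no inputs; the export input was probably "
            "baked into the graph as a constant."
        )
    unknown = sorted(set(feed) - set(expected))
    missing = sorted(set(expected) - set(feed))
    if unknown or missing:
        raise ValueError(
            f"Feed/input mismatch: unknown={unknown}, missing={missing}, "
            f"model inputs={expected}"
        )
    return expected
```

Calling `check_feed_names(ort_session, ort_inputs)` just before the first `ort_session.run(...)` would have reported the missing input node directly instead of the generic INVALID_ARGUMENT error.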

🔥 Solution

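Wrong input 1: the raw text, x_txt = [xxxx]. This causes the problem described above: the model has no input node because the text is captured as a constant.

Wrong input 2: x_txt = torch.randn(79, 77). This fails with the error below because the embedding layer requires integer indices: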

RuntimeError: Expected tensor for argument #1 'indices' to have one of the following scalar types: Long, Int; but got torch.cuda.FloatTensor instead (while checking arguments for embedding)

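Wrong input 3: x_txt = torch.randint(0, 49408, (79, 77))

Why 49408? It is the vocabulary size of the CLIP model. Valid token IDs run from 0 to 49407, each representing a word or token, and the upper bound of torch.randint is exclusive, so 49408 is the right value to pass.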
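The bounds can be checked with a standard-library sketch that mirrors the half-open range of torch.randint: with a vocabulary of 49408 entries, valid IDs are 0 through 49407. The helper `random_token_ids` and the two constants are illustrative names, not part of CLIP or PyTorch.

```python
import random

VOCAB_SIZE = 49408      # CLIP vocabulary size used in this post
CONTEXT_LENGTH = 77     # CLIP context length used in this post


def random_token_ids(batch, context_length=CONTEXT_LENGTH,
                     vocab_size=VOCAB_SIZE, seed=0):
    """Return batch x context_length random IDs in the range [0, vocab_size)."""
    rng = random.Random(seed)
    return [[rng.randrange(0, vocab_size) for _ in range(context_length)]
            for _ in range(batch)]


ids = random_token_ids(batch=79)
print(len(ids), len(ids[0]))                                   # 79 77
print(min(map(min, ids)) >= 0, max(map(max, ids)) < VOCAB_SIZE)  # True True
```

The same shape and range are what `torch.randint(0, 49408, (79, 77))` produces, only as a tensor instead of nested lists.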

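Input 3 exports without problems, but inference in onnxruntime then fails with the error below (explained in detail in the article "ONNXRuntimeError 9 NOT_IMPLEMENTED"). torch.randint produces int64 by default; casting the tensor to int32 is enough to fix it.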

[ONNXRuntimeError] : 9 : NOT_IMPLEMENTED : Could not find an implementation for ArgMax(13) node with name '/ArgMax

So the final, correct input is: x_txt = torch.randint(0, 49408, (79, 77)).to(torch.int32)

Complete code for exporting the CLIP text encoder to ONNX


import torch
import torch.nn as nn
from functools import partial
from typing import List


class TxtModelWrapper(nn.Module):

    def __init__(self,
                 clip_model_name: str,
                 download_root: str = None,
                 device: torch.device = "cuda" if torch.cuda.is_available() else "cpu",
                 jit: bool = False,
                 # additional params
                 visual_score: bool = False,
                 feats_loss_type: str = None,
                 feats_loss_weights: List[float] = None,
                 fc_loss_weight: float = None,
                 context_length: int = 77):
        super().__init__()

        import clip  # local import

        # check model info
        self.clip_model_name = clip_model_name
        self.device = device
        self.available_models = clip.available_models()
        assert clip_model_name in self.available_models, f"A model backbone: {clip_model_name} that does not exist"

        # load CLIP
        self.model, self.preprocess = clip.load(clip_model_name, device=self.device, jit=jit, download_root=download_root)
        self.model.eval()

        # load tokenize
        self.tokenize_fn = partial(clip.tokenize, context_length=context_length)
    
    def forward(self, tokens, norm: bool = True):
        tokens = tokens.to(self.device).to(torch.int32)  
        txt_features = self.model.encode_text(tokens)
        if norm:
            text_features = txt_features.mean(axis=0, keepdim=True)
            text_features_norm = text_features / text_features.norm(dim=-1, keepdim=True)
            return text_features_norm
        return txt_features

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # use the GPU if available, otherwise the CPU

clip_txt = TxtModelWrapper("ViT-B/32", device=device)

# Final, correct input: random token IDs cast to int32
x_txt = torch.randint(0, 49408, (79, 77)).to(torch.int32)

with torch.no_grad(): 
    torch.onnx.export(
        clip_txt, 
        x_txt, 
        "clip_txt1.onnx", 
        opset_version=17, 
        input_names=['input'], 
        output_names=['output'],
        dynamic_axes={'input' : {0 : 'batch_size',
                                 1 : 'seq'},
                      'output' : {0 : 'batch_size',
                                  1 : 'seq'}},
        do_constant_folding=False)
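Because the model is exported with an int32 input, the arrays fed at inference time need the same dtype. The sketch below compares the declared input types with the dtypes in a feed dictionary before running. It is an illustration under assumptions: `dtype_mismatches` and `ORT_TYPE_TO_DTYPE` are hypothetical names, the mapping covers only three common type strings of the form onnxruntime reports (such as `tensor(int32)`), and feed values are assumed to expose a NumPy-style `.dtype`.

```python
# Type strings as reported by session.get_inputs()[i].type for common tensor types.
ORT_TYPE_TO_DTYPE = {
    "tensor(int32)": "int32",
    "tensor(int64)": "int64",
    "tensor(float)": "float32",
}


def dtype_mismatches(session, feed):
    """List (input name, expected dtype, actual dtype) for each mismatched feed entry."""
    problems = []
    for node in session.get_inputs():
        expected = ORT_TYPE_TO_DTYPE.get(node.type)
        if node.name not in feed or expected is None:
            continue  # missing names and unmapped types are out of scope here
        actual = str(feed[node.name].dtype)
        if actual != expected:
            problems.append((node.name, expected, actual))
    return problems
```

If this reports an int64 array for an int32 input, converting the tokens with `tokens.astype(np.int32)` before calling `ort_session.run` should resolve it.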