[Stable Diffusion] TensorRT pitfalls with the 秋叶 (Akiba) all-in-one package: request timeouts, "Exporting to ONNX failed", and "Failed to parse ONNX model"

Overview

1. Installation failures
2. Fixing the errors
3. Low on VRAM? Don't bother with TensorRT

I. Installation failures

1. Use a proxy

I know what you are going to say: the proxy doesn't help either, and so on.
I was puzzled too; at first the download failed both on my local network and through a proxy.
It turned out to be a matter of luck.
Keep the proxy on and retry several times (it took me roughly 7 or 8 attempts) before the download finally succeeded.


2. Other odd failures

Request timed out, server not responding, bad response, and the like:
the cure is the same, keep the proxy on and retry a few times.
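
If the WebUI's downloader keeps failing even through a proxy, you can also try installing the TensorRT wheel manually into the launcher's bundled Python, using pip's retry and timeout flags. This is only a sketch: the tensorrt package name and NVIDIA's extra index are the usual sources, but the exact version the extension pins may differ on your install.

rem Run from the sd-webui-aki-v4.8 folder, using the bundled interpreter
python\python.exe -m pip install tensorrt --retries 10 --default-timeout 120 --extra-index-url https://pypi.nvidia.com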

II. Errors

1. Error case 1

ERROR:root:Exporting to ONNX failed. Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument mat1 in method wrapper_CUDA_addmm)

Full error output:

Disabling attention optimization
Exporting boleromixPony_v14 to TensorRT using - Batch Size: 1-1-1
Height: 768-768-768
Width: 768-768-768
Token Count: 75-75-75
Disabling attention optimization
ERROR:root:Exporting to ONNX failed. Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument mat1 in method wrapper_CUDA_addmm)
Building TensorRT engine... This can take a while, please check the progress in the terminal.
Building TensorRT engine for S:\app_AI\stableDiffusion-webui-aki\sd-webui-aki-v4.8\models\Unet-onnx\boleromixPony_v14.onnx: S:\app_AI\stableDiffusion-webui-aki\sd-webui-aki-v4.8\models\Unet-trt\boleromixPony_v14_ad5d6010_cc86_sample=2x4x96x96-timesteps=2-encoder_hidden_states=2x77x2048-y=2x2816.trt
Could not open file S:\app_AI\stableDiffusion-webui-aki\sd-webui-aki-v4.8\models\Unet-onnx\boleromixPony_v14.onnx
Could not open file S:\app_AI\stableDiffusion-webui-aki\sd-webui-aki-v4.8\models\Unet-onnx\boleromixPony_v14.onnx
[E] ModelImporter.cpp:773: Failed to parse ONNX model from file: S:\app_AI\stableDiffusion-webui-aki\sd-webui-aki-v4.8\models\Unet-onnx\boleromixPony_v14.onnx
[!] Failed to parse ONNX model. Does the model file exist and contain a valid ONNX model?
Traceback (most recent call last):
  File "S:\app_AI\stableDiffusion-webui-aki\sd-webui-aki-v4.8\python\lib\site-packages\gradio\routes.py", line 488, in run_predict
    output = await app.get_blocks().process_api(
  File "S:\app_AI\stableDiffusion-webui-aki\sd-webui-aki-v4.8\python\lib\site-packages\gradio\blocks.py", line 1431, in process_api
    result = await self.call_function(
  File "S:\app_AI\stableDiffusion-webui-aki\sd-webui-aki-v4.8\python\lib\site-packages\gradio\blocks.py", line 1103, in call_function
    prediction = await anyio.to_thread.run_sync(
  File "S:\app_AI\stableDiffusion-webui-aki\sd-webui-aki-v4.8\python\lib\site-packages\anyio\to_thread.py", line 31, in run_sync
    return await get_asynclib().run_sync_in_worker_thread(
  File "S:\app_AI\stableDiffusion-webui-aki\sd-webui-aki-v4.8\python\lib\site-packages\anyio\_backends\_asyncio.py", line 937, in run_sync_in_worker_thread
    return await future
  File "S:\app_AI\stableDiffusion-webui-aki\sd-webui-aki-v4.8\python\lib\site-packages\anyio\_backends\_asyncio.py", line 867, in run
    result = context.run(func, *args)
  File "S:\app_AI\stableDiffusion-webui-aki\sd-webui-aki-v4.8\python\lib\site-packages\gradio\utils.py", line 707, in wrapper
    response = f(*args, **kwargs)
  File "S:\app_AI\stableDiffusion-webui-aki\sd-webui-aki-v4.8\extensions\Stable-Diffusion-WebUI-TensorRT\ui_trt.py", line 126, in export_unet_to_trt
    ret = export_trt(
  File "S:\app_AI\stableDiffusion-webui-aki\sd-webui-aki-v4.8\extensions\Stable-Diffusion-WebUI-TensorRT\exporter.py", line 231, in export_trt
    ret = engine.build(
  File "S:\app_AI\stableDiffusion-webui-aki\sd-webui-aki-v4.8\extensions\Stable-Diffusion-WebUI-TensorRT\utilities.py", line 227, in build
    network = network_from_onnx_path(
  File "<string>", line 3, in network_from_onnx_path
  File "S:\app_AI\stableDiffusion-webui-aki\sd-webui-aki-v4.8\python\lib\site-packages\polygraphy\backend\base\loader.py", line 40, in __call__
    return self.call_impl(*args, **kwargs)
  File "S:\app_AI\stableDiffusion-webui-aki\sd-webui-aki-v4.8\python\lib\site-packages\polygraphy\util\util.py", line 710, in wrapped
    return func(*args, **kwargs)
  File "S:\app_AI\stableDiffusion-webui-aki\sd-webui-aki-v4.8\python\lib\site-packages\polygraphy\backend\trt\loader.py", line 247, in call_impl
    trt_util.check_onnx_parser_errors(parser, success)
  File "S:\app_AI\stableDiffusion-webui-aki\sd-webui-aki-v4.8\python\lib\site-packages\polygraphy\backend\trt\util.py", line 88, in check_onnx_parser_errors
    G_LOGGER.critical(
  File "S:\app_AI\stableDiffusion-webui-aki\sd-webui-aki-v4.8\python\lib\site-packages\polygraphy\logger\logger.py", line 605, in critical
    raise ExceptionType(message) from None
polygraphy.exception.exception.PolygraphyException: Failed to parse ONNX model. Does the model file exist and contain a valid ONNX model?
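
Reading the log from the top makes the chain of failures clear: the device-mismatch error aborts the ONNX export, so boleromixPony_v14.onnx is never written, and everything after it ("Could not open file", "Failed to parse ONNX model") is fallout from that first failure. The error itself means the UNet's weights were still on the CPU while the sample inputs sat on cuda:0. A toy reproduction, assuming a CUDA-capable GPU (an illustration only, not the extension's code):

import torch

lin = torch.nn.Linear(4, 4)            # module weights start on the CPU
x = torch.randn(1, 4, device="cuda")   # input tensor lives on cuda:0
# lin(x)  # RuntimeError: Expected all tensors to be on the same device ...
lin.to("cuda")                         # the fix: put module and inputs on one device
y = lin(x)
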
2. The fix

Go to this directory:

X:\XXX\sd-webui-aki-v4.8\extensions\Stable-Diffusion-WebUI-TensorRT

Replace these two files:

ui_trt.py
exporter.py

The replacement code is below. The essential change in both files is that the UNet and all of its sample inputs are explicitly moved onto one device (CUDA when available) before the ONNX export, which is exactly what the device-mismatch error complains about.

Replacement for ui_trt.py (copy it wholesale; if you grabbed it from CSDN, strip any watermark text that came along):

import os
import gc
import json
import logging
from collections import defaultdict

import torch
from safetensors.torch import save_file
import gradio as gr

from modules.shared import cmd_opts
from modules.ui_components import FormRow
from modules import sd_hijack, sd_models, shared
from modules.ui_common import refresh_symbol
from modules.ui_components import ToolButton

from model_helper import UNetModel
from exporter import export_onnx, export_trt, export_lora
from model_manager import modelmanager, cc_major, TRT_MODEL_DIR
from datastructures import SDVersion, ProfilePrests, ProfileSettings


profile_presets = ProfilePrests()

logging.basicConfig(level=logging.INFO)


def get_context_dim():
    if shared.sd_model.is_sd1:
        return 768
    elif shared.sd_model.is_sd2:
        return 1024
    elif shared.sd_model.is_sdxl:
        return 2048


def is_fp32():
    use_fp32 = False
    if cc_major < 7:
        use_fp32 = True
        print("FP16 has been disabled because your GPU does not support it.")
    return use_fp32


def export_unet_to_trt(
    batch_min,
    batch_opt,
    batch_max,
    height_min,
    height_opt,
    height_max,
    width_min,
    width_opt,
    width_max,
    token_count_min,
    token_count_opt,
    token_count_max,
    force_export,
    static_shapes,
    preset,
):
    sd_hijack.model_hijack.apply_optimizations("None")

    is_xl = shared.sd_model.is_sdxl
    model_name = shared.sd_model.sd_checkpoint_info.model_name

    profile_settings = ProfileSettings(
        batch_min,
        batch_opt,
        batch_max,
        height_min,
        height_opt,
        height_max,
        width_min,
        width_opt,
        width_max,
        token_count_min,
        token_count_opt,
        token_count_max,
    )
    if preset == "Default":
        profile_settings = profile_presets.get_default(is_xl=is_xl)
    use_fp32 = is_fp32()

    print(f"Exporting {model_name} to TensorRT using - {profile_settings}")
    profile_settings.token_to_dim(static_shapes)

    model_hash = shared.sd_model.sd_checkpoint_info.hash
    model_name = shared.sd_model.sd_checkpoint_info.model_name

    onnx_filename, onnx_path = modelmanager.get_onnx_path(model_name)
    timing_cache = modelmanager.get_timing_cache()

    diable_optimizations = is_xl
    embedding_dim = get_context_dim()

    modelobj = UNetModel(
        shared.sd_model.model.diffusion_model,
        embedding_dim,
        text_minlen=profile_settings.t_min,
        is_xl=is_xl,
    )
    modelobj.apply_torch_model()

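    # The key fix: keep the UNet and every tensor in the input profile on the
    # same device before exporting, so the ONNX trace never sees a cpu/cuda:0 mix.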
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    profile = modelobj.get_input_profile(profile_settings)
    modelobj.unet.to(device)
    for key, value in profile.items():
        if isinstance(value, torch.Tensor):
            profile[key] = value.to(device)
        elif isinstance(value, tuple):
            profile[key] = tuple(v.to(device) for v in value)

    export_onnx(
        onnx_path,
        modelobj,
        profile_settings,
        diable_optimizations=diable_optimizations,
    )
    gc.collect()
    torch.cuda.empty_cache()

    trt_engine_filename, trt_path = modelmanager.get_trt_path(
        model_name, model_hash, profile, static_shapes
    )

    if not os.path.exists(trt_path) or force_export:
        print(
            "Building TensorRT engine... This can take a while, please check the progress in the terminal."
        )
        gr.Info(
            "Building TensorRT engine... This can take a while, please check the progress in the terminal."
        )
        ret = export_trt(
            trt_path,
            onnx_path,
            timing_cache,
            profile=profile,
            use_fp16=not use_fp32,
        )
        if ret:
            return "## Export Failed due to unknown reason. See shell for more information. \n"

        print("TensorRT engines has been saved to disk.")
        modelmanager.add_entry(
            model_name,
            model_hash,
            profile,
            static_shapes,
            fp32=use_fp32,
            inpaint=True if modelobj.in_channels == 6 else False,
            refit=True,
            vram=0,
            unet_hidden_dim=modelobj.in_channels,
            lora=False,
        )
    else:
        print(
            "TensorRT engine found. Skipping build. You can enable Force Export in the Advanced Settings to force a rebuild if needed."
        )

    gc.collect()
    torch.cuda.empty_cache()

    return "## Exported Successfully \n"


def export_lora_to_trt(lora_name, force_export):
    is_xl = shared.sd_model.is_sdxl

    available_lora_models = get_lora_checkpoints()
    lora_name = lora_name.split(" ")[0]
    lora_model = available_lora_models.get(lora_name, None)
    if lora_model is None:
        return f"## No LoRA model found for {lora_name}"

    version = lora_model.get("version", SDVersion.Unknown)
    if version == SDVersion.Unknown:
        print(
            "LoRA SD version couldm't be determined. Please ensure the correct SD Checkpoint is selected."
        )

    model_name = shared.sd_model.sd_checkpoint_info.model_name
    model_hash = shared.sd_model.sd_checkpoint_info.hash

    if not version.match(shared.sd_model):
        print(
            f"""LoRA SD version ({version}) does not match the current SD version ({model_name}). 
            Please ensure the correct SD Checkpoint is selected."""
        )

    profile_settings = profile_presets.get_default(is_xl=False)
    print(f"Exporting {lora_name} to TensorRT using - {profile_settings}")
    profile_settings.token_to_dim(True)

    onnx_base_filename, onnx_base_path = modelmanager.get_onnx_path(model_name)
    if not os.path.exists(onnx_base_path):
        return f"## Please export the base model ({model_name}) first."

    embedding_dim = get_context_dim()

    modelobj = UNetModel(
        shared.sd_model.model.diffusion_model,
        embedding_dim,
        text_minlen=profile_settings.t_min,
        is_xl=is_xl,
    )
    modelobj.apply_torch_model()

    weights_map_path = modelmanager.get_weights_map_path(model_name)
    if not os.path.exists(weights_map_path):
        modelobj.export_weights_map(onnx_base_path, weights_map_path)

    lora_trt_name = f"{lora_name}.lora"
    lora_trt_path = os.path.join(TRT_MODEL_DIR, lora_trt_name)

    if os.path.exists(lora_trt_path) and not force_export:
        print(
            "TensorRT engine found. Skipping build. You can enable Force Export in the Advanced Settings to force a rebuild if needed."
        )
        return "## Exported Successfully \n"

    profile = modelobj.get_input_profile(profile_settings)
    refit_dict = export_lora(
        modelobj,
        onnx_base_path,
        weights_map_path,
        lora_model["filename"],
        profile_settings,
    )

    save_file(refit_dict, lora_trt_path)

    return "## Exported Successfully \n"


def get_version_from_filename(name):
    if "v1-" in name:
        return "1.5"
    elif "v2-" in name:
        return "2.1"
    elif "xl" in name:
        return "xl-1.0"
    else:
        return "Unknown"


def get_lora_checkpoints():
    available_lora_models = {}
    allowed_extensions = ["pt", "ckpt", "safetensors"]
    candidates = [
        p
        for p in os.listdir(cmd_opts.lora_dir)
        if p.split(".")[-1] in allowed_extensions
    ]

    for filename in candidates:
        metadata = {}
        name, ext = os.path.splitext(filename)
        config_file = os.path.join(cmd_opts.lora_dir, name + ".json")

        if ext == ".safetensors":
            metadata = sd_models.read_metadata_from_safetensors(
                os.path.join(cmd_opts.lora_dir, filename)
            )
        else:
            print(
                """LoRA {} is not a safetensor. This might cause issues when exporting to TensorRT.
                   Please ensure that the correct base model is selected when exporting.""".format(
                    name
                )
            )

        base_model = metadata.get("ss_sd_model_name", "Unknown")
        if os.path.exists(config_file):
            with open(config_file, "r") as f:
                config = json.load(f)
            try:
                version = SDVersion.from_str(config["sd version"])
            except Exception:
                version = SDVersion.Unknown

        else:
            version = SDVersion.Unknown
            print(
                "No config file found for {}. You can generate it in the LoRA tab.".format(
                    name
                )
            )

        available_lora_models[name] = {
            "filename": filename,
            "version": version,
            "base_model": base_model,
        }
    return available_lora_models


def get_valid_lora_checkpoints():
    available_lora_models = get_lora_checkpoints()
    return [f"{k} ({v['version']})" for k, v in available_lora_models.items()]


def diable_export(version):
    if version == "Default":
        return (
            gr.update(visible=False),
            gr.update(visible=True),
            gr.update(visible=False),
        )
    else:
        return (
            gr.update(visible=True),
            gr.update(visible=False),
            gr.update(visible=True),
        )


def disable_lora_export(lora):
    if lora is None:
        return gr.update(visible=False)
    else:
        return gr.update(visible=True)


def diable_visibility(hide):
    num_outputs = 8
    out = [gr.update(visible=not hide) for _ in range(num_outputs)]
    return out


def engine_profile_card():
    def get_md_table(
        h_min,
        h_opt,
        h_max,
        w_min,
        w_opt,
        w_max,
        b_min,
        b_opt,
        b_max,
        t_min,
        t_opt,
        t_max,
    ):
        md_table = (
            "|             	|   Min   	|   Opt   	|   Max   	| \n"
            "|-------------	|:-------:	|:-------:	|:-------:	| \n"
            "| Height      	| {h_min} 	| {h_opt} 	| {h_max} 	| \n"
            "| Width       	| {w_min} 	| {w_opt} 	| {w_max} 	| \n"
            "| Batch Size  	| {b_min} 	| {b_opt} 	| {b_max} 	| \n"
            "| Text-length 	| {t_min} 	| {t_opt} 	| {t_max} 	| \n"
        )
        return md_table.format(
            h_min=h_min,
            h_opt=h_opt,
            h_max=h_max,
            w_min=w_min,
            w_opt=w_opt,
            w_max=w_max,
            b_min=b_min,
            b_opt=b_opt,
            b_max=b_max,
            t_min=t_min,
            t_opt=t_opt,
            t_max=t_max,
        )

    available_models = modelmanager.available_models()

    model_md = defaultdict(list)
    loras_md = {}
    for base_model, models in available_models.items():
        for i, m in enumerate(models):
            # if m["config"].lora:
            #     loras_md[base_model] = m.get("base_model", None)
            #     continue

            s_min, s_opt, s_max = m["config"].profile.get(
                "sample", [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]
            )
            t_min, t_opt, t_max = m["config"].profile.get(
                "encoder_hidden_states", [[0, 0, 0], [0, 0, 0], [0, 0, 0]]
            )
            profile_table = get_md_table(
                s_min[2] * 8,
                s_opt[2] * 8,
                s_max[2] * 8,
                s_min[3] * 8,
                s_opt[3] * 8,
                s_max[3] * 8,
                max(s_min[0] // 2, 1),
                max(s_opt[0] // 2, 1),
                max(s_max[0] // 2, 1),
                (t_min[1] // 77) * 75,
                (t_opt[1] // 77) * 75,
                (t_max[1] // 77) * 75,
            )

            model_md[base_model].append(profile_table)

    available_loras = modelmanager.available_loras()
    for lora, path in available_loras.items():
        loras_md[f"{lora}"] = ""

    return model_md, loras_md


def on_ui_tabs():
    with gr.Blocks(analytics_enabled=False) as trt_interface:
        with gr.Row(equal_height=True):
            with gr.Column(variant="panel"):
                with gr.Tabs(elem_id="trt_tabs"):
                    with gr.Tab(label="TensorRT Exporter"):
                        gr.Markdown(
                            value="# TensorRT Exporter",
                        )

                        default_vals = profile_presets.get_default(is_xl=False)
                        version = gr.Dropdown(
                            label="Preset",
                            choices=profile_presets.get_choices(),
                            elem_id="sd_version",
                            default="Default",
                            value="Default",
                        )

                        with gr.Accordion(
                            "Advanced Settings", open=False, visible=False
                        ) as advanced_settings:
                            with FormRow(
                                elem_classes="checkboxes-row", variant="compact"
                            ):
                                static_shapes = gr.Checkbox(
                                    label="Use static shapes.",
                                    value=False,
                                    elem_id="trt_static_shapes",
                                )

                            with gr.Column(elem_id="trt_batch"):
                                trt_min_batch = gr.Slider(
                                    minimum=1,
                                    maximum=16,
                                    step=1,
                                    label="Min batch-size",
                                    value=default_vals.bs_min,
                                    elem_id="trt_min_batch",
                                )

                                trt_opt_batch = gr.Slider(
                                    minimum=1,
                                    maximum=16,
                                    step=1,
                                    label="Optimal batch-size",
                                    value=default_vals.bs_opt,
                                    elem_id="trt_opt_batch",
                                )
                                trt_max_batch = gr.Slider(
                                    minimum=1,
                                    maximum=16,
                                    step=1,
                                    label="Max batch-size",
                                    value=default_vals.bs_max,
                                    elem_id="trt_max_batch",
                                )

                            with gr.Column(elem_id="trt_height"):
                                trt_height_min = gr.Slider(
                                    minimum=256,
                                    maximum=4096,
                                    step=64,
                                    label="Min height",
                                    value=default_vals.h_min,
                                    elem_id="trt_min_height",
                                )
                                trt_height_opt = gr.Slider(
                                    minimum=256,
                                    maximum=4096,
                                    step=64,
                                    label="Optimal height",
                                    value=default_vals.h_opt,
                                    elem_id="trt_opt_height",
                                )
                                trt_height_max = gr.Slider(
                                    minimum=256,
                                    maximum=4096,
                                    step=64,
                                    label="Max height",
                                    value=default_vals.h_max,
                                    elem_id="trt_max_height",
                                )

                            with gr.Column(elem_id="trt_width"):
                                trt_width_min = gr.Slider(
                                    minimum=256,
                                    maximum=4096,
                                    step=64,
                                    label="Min width",
                                    value=default_vals.w_min,
                                    elem_id="trt_min_width",
                                )
                                trt_width_opt = gr.Slider(
                                    minimum=256,
                                    maximum=4096,
                                    step=64,
                                    label="Optimal width",
                                    value=default_vals.w_opt,
                                    elem_id="trt_opt_width",
                                )
                                trt_width_max = gr.Slider(
                                    minimum=256,
                                    maximum=4096,
                                    step=64,
                                    label="Max width",
                                    value=default_vals.w_max,
                                    elem_id="trt_max_width",
                                )

                            with gr.Column(elem_id="trt_token_count"):
                                trt_token_count_min = gr.Slider(
                                    minimum=75,
                                    maximum=750,
                                    step=75,
                                    label="Min prompt token count",
                                    value=default_vals.t_min,
                                    elem_id="trt_opt_token_count_min",
                                )
                                trt_token_count_opt = gr.Slider(
                                    minimum=75,
                                    maximum=750,
                                    step=75,
                                    label="Optimal prompt token count",
                                    value=default_vals.t_opt,
                                    elem_id="trt_opt_token_count_opt",
                                )
                                trt_token_count_max = gr.Slider(
                                    minimum=75,
                                    maximum=750,
                                    step=75,
                                    label="Max prompt token count",
                                    value=default_vals.t_max,
                                    elem_id="trt_opt_token_count_max",
                                )

                            with FormRow(
                                elem_classes="checkboxes-row", variant="compact"
                            ):
                                force_rebuild = gr.Checkbox(
                                    label="Force Rebuild.",
                                    value=False,
                                    elem_id="trt_force_rebuild",
                                )

                        button_export_unet = gr.Button(
                            value="Export Engine",
                            variant="primary",
                            elem_id="trt_export_unet",
                            visible=False,
                        )

                        button_export_default_unet = gr.Button(
                            value="Export Default Engine",
                            variant="primary",
                            elem_id="trt_export_default_unet",
                            visible=True,
                        )

                        version.change(
                            profile_presets.get_settings_from_version,
                            version,
                            [
                                trt_min_batch,
                                trt_opt_batch,
                                trt_max_batch,
                                trt_height_min,
                                trt_height_opt,
                                trt_height_max,
                                trt_width_min,
                                trt_width_opt,
                                trt_width_max,
                                trt_token_count_min,
                                trt_token_count_opt,
                                trt_token_count_max,
                                static_shapes,
                            ],
                        )
                        version.change(
                            diable_export,
                            version,
                            [
                                button_export_unet,
                                button_export_default_unet,
                                advanced_settings,
                            ],
                        )

                        static_shapes.change(
                            diable_visibility,
                            static_shapes,
                            [
                                trt_min_batch,
                                trt_max_batch,
                                trt_height_min,
                                trt_height_max,
                                trt_width_min,
                                trt_width_max,
                                trt_token_count_min,
                                trt_token_count_max,
                            ],
                        )

                    with gr.Tab(label="TensorRT LoRA"):
                        gr.Markdown("# Apply LoRA checkpoint to TensorRT model")
                        lora_refresh_button = gr.Button(
                            value="Refresh",
                            variant="primary",
                            elem_id="trt_lora_refresh",
                        )

                        trt_lora_dropdown = gr.Dropdown(
                            choices=get_valid_lora_checkpoints(),
                            elem_id="lora_model",
                            label="LoRA Model",
                            default=None,
                        )

                        with FormRow(elem_classes="checkboxes-row", variant="compact"):
                            trt_lora_force_rebuild = gr.Checkbox(
                                label="Force Rebuild.",
                                value=False,
                                elem_id="trt_lora_force_rebuild",
                            )

                        button_export_lora_unet = gr.Button(
                            value="Convert to TensorRT",
                            variant="primary",
                            elem_id="trt_lora_export_unet",
                            visible=False,
                        )

                        lora_refresh_button.click(
                            get_valid_lora_checkpoints,
                            None,
                            trt_lora_dropdown,
                        )
                        trt_lora_dropdown.change(
                            disable_lora_export,
                            trt_lora_dropdown,
                            button_export_lora_unet,
                        )

            with gr.Column(variant="panel"):
                with open(
                    os.path.join(os.path.dirname(os.path.abspath(__file__)), "info.md"),
                    "r",
                    encoding="utf-8",
                ) as f:
                    trt_info = gr.Markdown(elem_id="trt_info", value=f.read())

        with gr.Row(equal_height=False):
            with gr.Accordion("Output", open=True):
                trt_result = gr.Markdown(elem_id="trt_result", value="")

        def get_trt_profiles_markdown():
            profiles_md_string = ""
            engine_cards, lora_cards = engine_profile_card()
            for model, profiles in engine_cards.items():
                profiles_md_string += f"<details><summary>{model} ({len(profiles)} Profiles)</summary>\n\n"
                for i, profile in enumerate(profiles):
                    profiles_md_string += f"#### Profile {i} \n{profile}\n\n"
                profiles_md_string += "</details>\n"
            profiles_md_string += "</details>\n"

            profiles_md_string += "\n --- \n ## LoRA Profiles \n"
            for model, details in lora_cards.items():
                profiles_md_string += f"<details><summary>{model}</summary>\n\n"
                profiles_md_string += details
                profiles_md_string += "</details>\n"
            return profiles_md_string

        with gr.Column(variant="panel"):
            with gr.Row(equal_height=True, variant="compact"):
                button_refresh_profiles = ToolButton(
                    value=refresh_symbol, elem_id="trt_refresh_profiles", visible=True
                )
                profile_header_md = gr.Markdown(
                    value=f"## Available TensorRT Engine Profiles"
                )
            with gr.Row(equal_height=True):
                trt_profiles_markdown = gr.Markdown(
                    elem_id=f"trt_profiles_markdown", value=get_trt_profiles_markdown()
                )

        button_refresh_profiles.click(
            lambda: gr.Markdown.update(value=get_trt_profiles_markdown()),
            outputs=[trt_profiles_markdown],
        )

        button_export_unet.click(
            export_unet_to_trt,
            inputs=[
                trt_min_batch,
                trt_opt_batch,
                trt_max_batch,
                trt_height_min,
                trt_height_opt,
                trt_height_max,
                trt_width_min,
                trt_width_opt,
                trt_width_max,
                trt_token_count_min,
                trt_token_count_opt,
                trt_token_count_max,
                force_rebuild,
                static_shapes,
                version,
            ],
            outputs=[trt_result],
        )

        button_export_default_unet.click(
            export_unet_to_trt,
            inputs=[
                trt_min_batch,
                trt_opt_batch,
                trt_max_batch,
                trt_height_min,
                trt_height_opt,
                trt_height_max,
                trt_width_min,
                trt_width_opt,
                trt_width_max,
                trt_token_count_min,
                trt_token_count_opt,
                trt_token_count_max,
                force_rebuild,
                static_shapes,
                version,
            ],
            outputs=[trt_result],
        )

        button_export_lora_unet.click(
            export_lora_to_trt,
            inputs=[trt_lora_dropdown, trt_lora_force_rebuild],
            outputs=[trt_result],
        )

    return [(trt_interface, "TensorRT", "tensorrt")]

Replacement for exporter.py (again, copy it wholesale and strip any CSDN watermark text):

import os
import time
import shutil
import json
from pathlib import Path
from logging import info, error
from collections import OrderedDict
from typing import List, Tuple

import torch
import torch.nn.functional as F
import numpy as np
import onnx
from onnx import numpy_helper
from optimum.onnx.utils import (
    _get_onnx_external_data_tensors,
    check_model_uses_external_data,
)


from modules import shared

from utilities import Engine
from datastructures import ProfileSettings
from model_helper import UNetModel


def apply_lora(model: torch.nn.Module, lora_path: str, inputs: Tuple[torch.Tensor]) -> torch.nn.Module:
    try:
        import sys

        sys.path.append("extensions-builtin/Lora")
        import importlib

        networks = importlib.import_module("networks")
        network = importlib.import_module("network")
        lora_net = importlib.import_module("extra_networks_lora")
    except Exception as e:
        error(e)
        error("LoRA not found. Please install LoRA extension first from ...")
    model.forward(*inputs)
    lora_name = os.path.splitext(os.path.basename(lora_path))[0]
    networks.load_networks(
        [lora_name], [1.0], [1.0], [None]
    )

    model.forward(*inputs)
    return model


def get_refit_weights(
    state_dict: dict, onnx_opt_path: str, weight_name_mapping: dict, weight_shape_mapping: dict
) -> dict:
    refit_weights = OrderedDict()
    onnx_opt_dir = os.path.dirname(onnx_opt_path)
    onnx_opt_model = onnx.load(onnx_opt_path)
    # Create initializer data hashes
    initializer_hash_mapping = {}
    onnx_data_mapping = {}
    for initializer in onnx_opt_model.graph.initializer:
        initializer_data = numpy_helper.to_array(
            initializer, base_dir=onnx_opt_dir
        ).astype(np.float16)
        initializer_hash = hash(initializer_data.data.tobytes())
        initializer_hash_mapping[initializer.name] = initializer_hash
        onnx_data_mapping[initializer.name] = initializer_data

    for torch_name, initializer_name in weight_name_mapping.items():
        initializer_hash = initializer_hash_mapping[initializer_name]
        wt = state_dict[torch_name]

        # get shape transform info
        initializer_shape, is_transpose = weight_shape_mapping[torch_name]
        if is_transpose:
            wt = torch.transpose(wt, 0, 1)
        else:
            wt = torch.reshape(wt, initializer_shape)

        # include weight if hashes differ
        wt_hash = hash(wt.cpu().detach().numpy().astype(np.float16).data.tobytes())
        if initializer_hash != wt_hash:
            delta = wt - torch.tensor(onnx_data_mapping[initializer_name]).to(wt.device)
            refit_weights[initializer_name] = delta.contiguous()

    return refit_weights


def export_lora(
    modelobj: UNetModel,
    onnx_path: str,
    weights_map_path: str,
    lora_name: str,
    profile: ProfileSettings,
) -> dict:
    info("Exporting to ONNX...")
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    inputs = modelobj.get_sample_input(
        profile.bs_opt * 2,
        profile.h_opt // 8,
        profile.w_opt // 8,
        profile.t_opt,
    )
    inputs = tuple(input_tensor.to(device) for input_tensor in inputs)
    modelobj.unet.to(device)

    with open(weights_map_path, "r") as fp_wts:
        print(f"[I] Loading weights map: {weights_map_path} ")
        [weights_name_mapping, weights_shape_mapping] = json.load(fp_wts)

    with torch.inference_mode(), torch.autocast("cuda"):
        modelobj.unet = apply_lora(
            modelobj.unet, os.path.splitext(lora_name)[0], inputs
        )

        refit_dict = get_refit_weights(
            modelobj.unet.state_dict(),
            onnx_path,
            weights_name_mapping,
            weights_shape_mapping,
        )

    return refit_dict


def swap_sdpa(func):
    def wrapper(*args, **kwargs):
        swap_sdpa = hasattr(F, "scaled_dot_product_attention")
        old_sdpa = (
            getattr(F, "scaled_dot_product_attention", None) if swap_sdpa else None
        )
        if swap_sdpa:
            delattr(F, "scaled_dot_product_attention")
        ret = func(*args, **kwargs)
        if swap_sdpa and old_sdpa:
            setattr(F, "scaled_dot_product_attention", old_sdpa)
        return ret

    return wrapper


@swap_sdpa
def export_onnx(
    onnx_path: str,
    modelobj: UNetModel,
    profile: ProfileSettings,
    opset: int = 17,
    diable_optimizations: bool = False,
):
    info("Exporting to ONNX...")
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    inputs = modelobj.get_sample_input(
        profile.bs_opt * 2,
        profile.h_opt // 8,
        profile.w_opt // 8,
        profile.t_opt,
    )
    inputs = tuple(input_tensor.to(device) for input_tensor in inputs)
    modelobj.unet.to(device)

    if not os.path.exists(onnx_path):
        _export_onnx(
            modelobj.unet,
            inputs,
            Path(onnx_path),
            opset,
            modelobj.get_input_names(),
            modelobj.get_output_names(),
            modelobj.get_dynamic_axes(),
            modelobj.optimize if not diable_optimizations else None,
        )


def _export_onnx(
    model: torch.nn.Module, inputs: Tuple[torch.Tensor], path: Path, opset: int, in_names: List[str], out_names: List[str], dyn_axes: dict, optimizer=None
):
    tmp_dir = os.path.abspath("onnx_tmp")
    os.makedirs(tmp_dir, exist_ok=True)
    tmp_path = os.path.join(tmp_dir, "model.onnx")
    try:
        info("Exporting to ONNX...")
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model.to(device)
        inputs = tuple(inp.to(device) for inp in inputs)
        with torch.inference_mode(), torch.autocast("cuda"):
            torch.onnx.export(
                model,
                inputs,
                tmp_path,
                export_params=True,
                opset_version=opset,
                do_constant_folding=True,
                input_names=in_names,
                output_names=out_names,
                dynamic_axes=dyn_axes,
            )
    except Exception as e:
        error("Exporting to ONNX failed. {}".format(e))
        return

    info("Optimize ONNX.")
    os.makedirs(path.parent, exist_ok=True)
    onnx_model = onnx.load(tmp_path, load_external_data=False)
    model_uses_external_data = check_model_uses_external_data(onnx_model)

    if model_uses_external_data:
        info("ONNX model uses external data. Saving as external data.")
        tensors_paths = _get_onnx_external_data_tensors(onnx_model)
        onnx_model = onnx.load(tmp_path, load_external_data=True)
        onnx.save(
            onnx_model,
            str(path),
            save_as_external_data=True,
            all_tensors_to_one_file=True,
            location=path.name + "_data",
            size_threshold=1024,
        )

    if optimizer is not None:
        try:
            onnx_opt_graph = optimizer("unet", onnx_model)
            onnx.save(onnx_opt_graph, path)
        except Exception as e:
            error("Optimizing ONNX failed. {}".format(e))
            return

    if not model_uses_external_data and optimizer is None:
        shutil.move(tmp_path, str(path))

    shutil.rmtree(tmp_dir)


def export_trt(trt_path: str, onnx_path: str, timing_cache: str, profile: dict, use_fp16: bool):
    engine = Engine(trt_path)

    # TODO Still approx. 2gb of VRAM unaccounted for...
    model = shared.sd_model.cpu()
    torch.cuda.empty_cache()

    s = time.time()
    ret = engine.build(
        onnx_path,
        use_fp16,
        enable_refit=True,
        enable_preview=True,
        timing_cache=timing_cache,
        input_profile=[profile],
        # hwCompatibility=hwCompatibility,
    )
    e = time.time()
    info(f"Time taken to build: {(e-s)}s")

    shared.sd_model = model.cuda()
    return ret

III. Save the files and restart the WebUI

1. A friendly heads-up

My card is an RTX 3080 with 10 GB of VRAM, yet it could barely run TensorRT even at 512*512. Generation genuinely got faster, but honestly, folks, it was rough, and it never met my needs.
If your GPU is short on VRAM, I don't recommend TensorRT; it is a form of torture.
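
If you are not sure how much VRAM your card exposes, you can check from the WebUI's bundled Python in a couple of lines (a small sketch using PyTorch's standard CUDA API):

import torch

# Print the name and total VRAM of the first CUDA device.
props = torch.cuda.get_device_properties(0)
print(f"{props.name}: {props.total_memory / 1024**3:.1f} GB")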
