Exporting an RWKV Model to ONNX

Test model:

https://huggingface.co/RWKV/rwkv-5-world-3b

Before exporting, apply one modification to modeling_rwkv5.py:

#        out = out.reshape(B * T, H * S)
        out = out.reshape(B * T, H * S, 1) # <<--- modified
        out = F.group_norm(out, num_groups=H, weight=lxw, bias=lxb).reshape(B, T, H * S)

This is needed because PyTorch's ONNX export currently has a bug: group_norm with a 2-D input cannot be exported, so the input is reshaped to 3-D first.
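
A standalone repro sketch of the workaround (a toy module for illustration, not code from modeling_rwkv5.py); on affected PyTorch versions the 2-D export fails, while reshaping to 3-D exports cleanly:

import torch
import torch.nn.functional as F

class GN2D(torch.nn.Module):
    def forward(self, x):  # x: (N, C), the 2-D case that trips the exporter
        return F.group_norm(x, num_groups=4)

class GN3D(torch.nn.Module):
    def forward(self, x):  # same math, with a trailing size-1 dim appended
        return F.group_norm(x.reshape(*x.shape, 1), num_groups=4).reshape(x.shape)

x = torch.randn(8, 16)
# torch.onnx.export(GN2D(), (x,), "gn2d.onnx")  # may fail on affected versions
torch.onnx.export(GN3D(), (x,), "gn3d.onnx")    # exports fine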

Notes:

rwkv_linear_attention_v5_cpu uses for t in range(T): to split the computation across time steps, so the ONNX graph exported for the initial prompt pass differs structurally from the one for subsequent decoding steps. This part needs to be reworked before a single exported ONNX model can serve both the prompt and decoding phases.

Branching logic such as if hidden.size(1) == 1 can cause the same problem.
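
A toy illustration of why such code breaks the export (hypothetical module, not from the RWKV code): torch.onnx.export traces the model, and under tracing x.size(1) is a concrete Python integer, so the loop body is copied T times into the graph. Tracing with a prompt-length input and with a single decode token therefore yields structurally different graphs:

import torch

class LoopOverT(torch.nn.Module):
    def forward(self, x):  # x: (B, T, C)
        out = torch.zeros_like(x[:, 0])
        for t in range(x.size(1)):  # unrolled at trace time
            out = out + x[:, t]
        return out

g_prompt = torch.jit.trace(LoopOverT(), torch.randn(1, 4, 8))  # 4 add nodes
g_decode = torch.jit.trace(LoopOverT(), torch.randn(1, 1, 8))  # 1 add node
print(g_prompt.graph)
print(g_decode.graph)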

In addition, this RWKV implementation can be further optimized for efficient inference. For example, each layer's state is currently updated with

        state[1][:, :, :, :, self.layer_id] = layer_state

Writing into the innermost dimension like this performs significantly worse than putting layer_id on the outermost dimension:

        state[1][self.layer_id, :, :, :, :] = layer_state

One could even go further and, as Transformer-architecture models do with their KV caches, keep each layer's layer_state as a separate entry in a list. This increases the number of model inputs and outputs, but it avoids the costly ScatterND operator entirely.
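
A minimal sketch of that per-layer-state alternative (toy classes for illustration; the names are not from modeling_rwkv5.py): each layer receives its own state tensor and returns the updated one, so the exported graph simply has one state output per layer and never scatters into a stacked tensor:

import torch
from torch import nn

class ToyLayer(nn.Module):
    def __init__(self, hidden):
        super().__init__()
        self.proj = nn.Linear(hidden, hidden)

    def forward(self, x, layer_state):
        new_state = layer_state + x.mean(dim=1)  # stand-in for the real state update
        return self.proj(x), new_state           # return the state instead of scattering it

class ToyModel(nn.Module):
    def __init__(self, hidden, num_layers):
        super().__init__()
        self.layers = nn.ModuleList(ToyLayer(hidden) for _ in range(num_layers))

    def forward(self, x, layer_states):  # layer_states: one tensor per layer
        new_states = []
        for layer, s in zip(self.layers, layer_states):
            x, s = layer(x, s)
            new_states.append(s)         # each entry becomes a separate ONNX output
        return x, new_states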

Reference export code (you can also try exporting with device=cpu):

import os
import argparse
import torch
from torch import nn
from transformers import AutoModelForCausalLM, AutoTokenizer


class LLMForCausalLMWrapper(nn.Module):
    def __init__(self, model, config, args):
        super().__init__()
        self.model = model
        self.config = config
        self.args = args

    def forward(
        self,
        input_ids,
        state,
    ):
        outputs = self.model(
            input_ids=input_ids,
            state=state,
            use_cache=True,
        )
        logits = outputs.logits
        state_out = outputs.state
        return logits, state_out


def export_llm_to_single_onnx(model, config, dtype, args, model_name):
    model_wrapper = LLMForCausalLMWrapper(model, config, args)

    onnx_file_name = os.path.join(args.out_dir, f"{model_name}.onnx")

    hidden_size = config.hidden_size
    layer_num = config.num_hidden_layers
    # in this model's config, num_attention_heads holds the per-head size,
    # so this division gives the number of heads
    head_num = config.hidden_size // config.num_attention_heads
    head_hidden_size = config.hidden_size // head_num  # per-head size

    batch = 1
    N = 4  # example prompt length used for tracing

    input_ids_shape = [batch, N]
    input_ids = torch.ones(input_ids_shape, dtype=torch.int64).to(args.device)

    dynamic_axes = {
        'input_ids': {1: 'N', },
    }
    if args.dyn_batch:
        dynamic_axes['input_ids'][0] = "batch"
        # the state inputs share the batch dimension, so mark them dynamic too
        for name in ["state_0_in", "state_1_in", "state_2_in"]:
            dynamic_axes[name] = {0: "batch"}

    # RWKV-5 state: attention token-shift state, per-head wkv state, and FFN
    # token-shift state, with all layers stacked on the last dimension
    state_0 = torch.randn([batch, hidden_size, layer_num], dtype=dtype).to(args.device)
    state_1 = torch.randn([batch, head_num, head_hidden_size, head_hidden_size, layer_num], dtype=dtype).to(args.device)
    state_2 = torch.randn([batch, hidden_size, layer_num], dtype=dtype).to(args.device)

    state = [state_0, state_1, state_2]
    in_names = ["input_ids", "state_0_in", "state_1_in", "state_2_in"]
    out_names = ["lm_logits", "state_0_out", "state_1_out", "state_2_out"]

    input_datas = (input_ids, state)

    torch.onnx.export(
        model_wrapper,
        input_datas,
        onnx_file_name,
        opset_version=args.opset,
        do_constant_folding=True,
        input_names=in_names,
        output_names=out_names,
        dynamic_axes=dynamic_axes,
    )


def export_rwkv(args):
    device = args.device
    dtype_map = {
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
    }
    dtype = dtype_map[args.dtype]

    print(f"begin load model from {args.model_path}")
    model = AutoModelForCausalLM.from_pretrained(
        args.model_path, device_map=device, torch_dtype=dtype, trust_remote_code=True).eval()

    model.rwkv.blocks = model.rwkv.blocks[:1]  # debug only: export a single block; remove this line for the full model

    print(f"finish load model from {args.model_path}")
    config = model.config
    print("config:", config)

    print(f"begin export llm")
    export_llm_to_single_onnx(model, config, dtype, args, "llm_onnx")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description='export llm',
    )
    parser.add_argument('-m', '--model_path', required=True, type=str)
    parser.add_argument('-o', '--out_dir', required=False, type=str, default="")
    parser.add_argument('--opset', required=False, type=int, default=15)
    parser.add_argument('-d', '--device', required=False, type=str, choices=["cpu", "cuda"], default="cuda")
    parser.add_argument('-p', '--dtype', required=False, type=str,
                        choices=["float32", "float16", "bfloat16"], default="float16")
    parser.add_argument('--dyn_batch', action='store_true')

    args = parser.parse_args()

    export_rwkv(args)
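
A typical invocation (the script file name here is arbitrary):

python export_rwkv_onnx.py -m /path/to/rwkv-5-world-3b -d cpu -p float32

And a hedged smoke test of the exported model with onnxruntime, assuming a float32 cpu export; the shapes below follow the 3B config (hidden_size=2560, 32 layers, head size 64, hence 40 heads), so adjust them for other model sizes:

import numpy as np
import onnxruntime as ort

hidden_size, layer_num, head_num, head_size = 2560, 32, 40, 64  # 3B config values

sess = ort.InferenceSession("llm_onnx.onnx", providers=["CPUExecutionProvider"])
feeds = {
    "input_ids": np.ones((1, 4), dtype=np.int64),
    "state_0_in": np.zeros((1, hidden_size, layer_num), dtype=np.float32),
    "state_1_in": np.zeros((1, head_num, head_size, head_size, layer_num), dtype=np.float32),
    "state_2_in": np.zeros((1, hidden_size, layer_num), dtype=np.float32),
}
logits, s0, s1, s2 = sess.run(None, feeds)
print(logits.shape)  # (1, 4, vocab_size)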

For exporting other models, and for running onnxsim on large models, see:

https://github.com/luchangli03/export_llama_to_onnx (export llama to onnx)

https://github.com/luchangli03/onnxsim_large_model (simplify >2GB large onnx model)
