Exporting the Google Gemma model to ONNX

This post describes how to modify the Gemma model code (now built into recent versions of the Transformers library), optimize the logits computation, and simplify the attention-mask handling, in order to improve the performance of the exported ONNX model.

The reference code below is adapted from GitHub - luchangli03/export_llama_to_onnx: export llama to onnx (the changes will be merged back into that repo later).

Model weights:

https://huggingface.co/google/gemma-2b-it

You can make a few modifications to modeling_gemma.py (the model code is built into transformers after upgrading to the latest version) to improve the performance of the exported ONNX model:

1. The original logits computation in GemmaForCausalLM is:

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)

Change it to:

        hidden_states = outputs[0]
        hidden_states = hidden_states[:,-1:,:]
        logits = self.lm_head(hidden_states)

This reduces the lm_head computation in the prefill stage, since only the last position's hidden state is needed to sample the next token.
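
As a rough illustration (the toy shapes below are made up; gemma-2b's real values are roughly hidden_size=2048 and vocab_size=256000, so the saving during prefill is substantial):

import torch

# Toy shapes for illustration only
N, hidden_size, vocab_size = 8, 16, 32
hidden_states = torch.randn(1, N, hidden_size)
lm_head = torch.nn.Linear(hidden_size, vocab_size, bias=False)

logits_all = lm_head(hidden_states)              # original: [1, N, vocab_size]
logits_last = lm_head(hidden_states[:, -1:, :])  # modified: [1, 1, vocab_size]

# The last row is identical either way; during prefill only this row is needed
# to sample the next token, so the other N-1 projections are wasted work.
assert torch.allclose(logits_all[:, -1:, :], logits_last)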

2. The model uses GemmaSdpaAttention, and the exported ONNX model indexes into a very large tensor just to obtain the attention mask:

causal_mask = attention_mask
if attention_mask is not None and cache_position is not None:
    causal_mask = causal_mask[:, :, cache_position, : key_states.shape[-2]]

This adds both extra memory and extra computation. Instead, the expanded attention mask can be passed in directly as an ONNX input, which removes this storage and computation entirely.

For some reason many models (e.g. Qwen) take a [1, seq_len] mask as input and expand it internally into a [1, 1, seq_len, sumN] mask; all of that can be replaced by having the model accept the [1, 1, seq_len, sumN] mask directly (a host-side sketch of building such a mask follows the diff below).

The corresponding changes to modeling_gemma.py are:

class GemmaModel(GemmaPreTrainedModel):
    def forward(...):
        ...
        # use the externally supplied 4-D mask as-is instead of rebuilding it
        # causal_mask = self._update_causal_mask(attention_mask, inputs_embeds)
        causal_mask = attention_mask
        ...

class GemmaSdpaAttention(GemmaAttention):
    def forward(...):
        ...
        # no longer index into the large causal mask tensor
        # if attention_mask is not None and cache_position is not None:
        #     causal_mask = causal_mask[:, :, cache_position, : key_states.shape[-2]]
        ...
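
With these changes, the attention_mask fed into the model must already be the expanded 4-D additive mask. Below is a minimal host-side sketch of building such a mask (my own illustration, not from the original code; the full version including padding is the get_atten_mask function in the inference example further down):

import numpy as np

def build_causal_mask(N, sumN, min_val=-65504.0):
    """Additive causal mask of shape [1, 1, N, sumN] for the last N query
    positions of a sequence whose total length (past + new) is sumN."""
    lastSum = sumN - N  # number of cached (past) positions
    mask = np.zeros((N, sumN), dtype="float16")
    for i in range(N):
        # the query at absolute position lastSum + i must not see later positions
        mask[i, lastSum + i + 1:] = min_val
    return mask.reshape(1, 1, N, sumN)

# prefill: N == sumN gives a lower-triangular mask; decode: N == 1 gives all zeros
print(build_causal_mask(3, 3)[0, 0])
print(build_causal_mask(1, 4)[0, 0])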

Model export code (it assumes the modifications above have been applied; if you don't want to modify modeling_gemma.py, just change the shape and dtype of the attention mask here accordingly):

import os
import argparse
import logging
import torch
from torch import nn
from transformers import AutoModelForCausalLM, AutoTokenizer


class LLMForCausalLMWrapper(nn.Module):
    def __init__(self, model, config, args):
        super().__init__()
        self.model = model
        self.config = config
        self.args = args

    def forward(
        self,
        input_ids,
        attention_mask,
        position_ids,
        past_key_values,
        output_attentions=False,
        output_hidden_states=False,
        use_cache=True,
    ):
        """
        Note: you can modify modeling_gemma.py to make the converted model more efficient:
        hidden_states = outputs[0]
        hidden_states = hidden_states[:,-1:,:]
        logits = self.lm_head(hidden_states)
        """
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=None,
            use_cache=True,
        )

        logits = outputs.logits
        kv_caches_out = []
        for past_kv in outputs.past_key_values:
            kv_caches_out.extend(past_kv)

        topk_outputs = []
        if self.args.add_topk_warper > 0:
            logging.warning("add topk warper to the model output")
            if self.args.topk < 0:
                raise ValueError(f"topk {self.args.topk} is invalid")
            topk_outputs = torch.topk(logits, k=self.args.topk, dim=-1)

        return logits, *kv_caches_out, *topk_outputs


def export_llm_to_single_onnx(model, config, dtype, args, model_name):
    llama_model_wrapper = LLMForCausalLMWrapper(model, config, args)

    onnx_file_name = os.path.join(args.out_dir, f"{model_name}.onnx")

    layer_num = len(model.model.layers)

    hidden_size = config.hidden_size
    kv_head_num = config.num_key_value_heads
    head_dim = config.head_dim

    batch = 1
    N = 1            # number of new tokens fed in this step
    sumN = 32        # total sequence length so far (past + new)
    lastSum = sumN - N

    input_ids_shape = [batch, N]
    input_ids = torch.ones(input_ids_shape, dtype=torch.int64).to(args.device)
    # Note: orig atten_mask shape is [1, sumN]
    attention_mask = torch.randn([batch, 1, N, sumN], dtype=dtype).to(args.device)
    position_ids = torch.ones([batch, N], dtype=torch.int64).to(args.device)

    in_names = ["input_ids", "attention_mask", "position_ids"]

    dynamic_axes = {
        'input_ids': {1: 'N', },
        'attention_mask': {2: 'N', 3: 'sumN'},
        "position_ids": {1: 'N', },
    }
    if args.dyn_batch:
        dynamic_axes['input_ids'][0] = "batch"
        dynamic_axes['attention_mask'][0] = "batch"
        dynamic_axes['position_ids'][0] = "batch"

    kv_caches_in = []
    out_names = ["lm_logits"]

    kv_cache_in_shape = [1, kv_head_num, lastSum, head_dim]
    kv_cache_dyn_axes = {2: "sumN-N"}

    if args.dyn_batch:
        kv_cache_dyn_axes[0] = "batch"

    past_key_values = []

    for i in range(layer_num):
        past_key_in = torch.randn(kv_cache_in_shape, dtype=dtype).to(args.device)
        past_value_in = torch.randn(kv_cache_in_shape, dtype=dtype).to(args.device)

        kv_caches_in.extend([past_key_in, past_value_in])
        in_names.extend([f"past_key_in{i}", f"past_value_in{i}"])
        out_names.extend([f"past_key{i}", f"past_value{i}"])

        dynamic_axes[f"past_key_in{i}"] = kv_cache_dyn_axes
        dynamic_axes[f"past_value_in{i}"] = kv_cache_dyn_axes

        past_key_values.append((past_key_in, past_value_in))

    input_datas = (input_ids, attention_mask, position_ids, past_key_values)

    torch.onnx.export(
        llama_model_wrapper,
        input_datas,
        onnx_file_name,
        opset_version=args.opset,
        do_constant_folding=True,
        input_names=in_names,
        output_names=out_names,
        dynamic_axes=dynamic_axes,
    )


def export_llama(args):
    device = args.device
    dtype_map = {
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
    }
    dtype = dtype_map[args.dtype]

    print(f"begin load model from {args.model_path}")
    model = AutoModelForCausalLM.from_pretrained(
        args.model_path, device_map=device, torch_dtype=dtype, trust_remote_code=True).eval()

    # model.model.layers = model.model.layers[:1]  # only export one layer for debug

    print(f"finish load model from {args.model_path}")
    config = model.config
    print("config:", config)

    print(f"begin export llm")
    export_llm_to_single_onnx(model, config, dtype, args, "llm_onnx")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description='export llm',
    )
    parser.add_argument('-m', '--model_path', required=True, type=str)
    parser.add_argument('-o', '--out_dir', required=False, type=str, default="")
    parser.add_argument('--opset', required=False, type=int, default=15)
    parser.add_argument('-d', '--device', required=False, type=str, choices=["cpu", "cuda"], default="cuda")
    parser.add_argument('-p', '--dtype', required=False, type=str,
                        choices=["float32", "float16", "bfloat16"], default="float16")
    parser.add_argument('--add_topk_warper', required=False, type=int, default=0)
    parser.add_argument('--topk', required=False, type=int, default=4)
    parser.add_argument('--dyn_batch', action='store_true')

    args = parser.parse_args()
    export_llama(args)
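
A possible invocation (the script filename and paths here are placeholders; note that the output directory must already exist, since the script does not create it):

python export_gemma_onnx.py -m /path/to/gemma-2b-it -o ./gemma_onnx -d cuda -p float16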

The exported ONNX file can be simplified with onnxsim; since the model is larger than 2 GB, use the large-model variant:

GitHub - luchangli03/onnxsim_large_model: simplify >2GB large onnx model
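
For ONNX files under the 2 GB protobuf limit, the stock onnx-simplifier Python API is enough (a minimal sketch below); the exported gemma-2b model is larger than that, which is what the linked repo handles.

import onnx
from onnxsim import simplify

# Works for ONNX files under the 2 GB protobuf limit; for larger models
# use the onnxsim_large_model workflow from the repo linked above.
model = onnx.load("llm_onnx.onnx")
model_sim, ok = simplify(model)
assert ok, "onnxsim simplification check failed"
onnx.save(model_sim, "llm_onnx_sim.onnx")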

Inference example for the exported ONNX model (the helper modules onnx_rt_utils and sample_utils are in GitHub - luchangli03/export_llama_to_onnx: export llama to onnx):

import numpy as np
from onnx_rt_utils import OnnxRuntimeModel, get_random_data
from sample_utils import sample_topk
from transformers import AutoTokenizer


def prepare_kv_cache_round0(glm_model_inputs, layer_num, lastSum):
    """
    only used at the first time
    in round 0, actually the lastSum is 0, thus past_key_in, past_value_in are empty tensor
    """
    for i in range(layer_num):
        # gemma-2b uses multi-query attention: 1 KV head with head_dim = 256
        past_key_in = get_random_data([1, 1, lastSum, 256], "float16")
        past_value_in = get_random_data([1, 1, lastSum, 256], "float16")

        past_key_in_name = f"past_key_in{i}"
        past_value_in_name = f"past_value_in{i}"
        glm_model_inputs[past_key_in_name] = past_key_in
        glm_model_inputs[past_value_in_name] = past_value_in
    return glm_model_inputs


def prepare_kv_cache_from_outputs(glm_model_inputs, decoder_outputs, layer_num):
    offset = 1
    for i in range(layer_num):
        past_key_in_name = f"past_key_in{i}"
        past_value_in_name = f"past_value_in{i}"

        glm_model_inputs[past_key_in_name] = decoder_outputs[offset + i * 2]
        glm_model_inputs[past_value_in_name] = decoder_outputs[offset + i * 2 + 1]
    return glm_model_inputs


def get_atten_mask(N, sumN, padded_len):
    attention_mask = np.zeros(shape=[N * padded_len], dtype="float16")

    pad_num = padded_len - sumN
    if (N == sumN):
        for i in range(N):
            mask_num = N - 1 - i + pad_num
            start = padded_len - mask_num
            for j in range(start, padded_len):
                attention_mask[i * padded_len + j] = -65504
    else:
        if (N != 1):
            raise ValueError("N is not 1")
        lastSum = sumN - N
        for i in range(pad_num):
            attention_mask[lastSum + i] = -65504

    attention_mask = attention_mask.reshape([N, padded_len])
    return attention_mask
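
# Sanity check (my own illustration, not part of the original script):
# get_atten_mask(3, 3, padded_len=3) returns a lower-triangular additive mask,
#   [[     0., -65504., -65504.],
#    [     0.,      0., -65504.],
#    [     0.,      0.,      0.]]
# i.e. position i can only attend to positions <= i; masked slots hold the
# float16 minimum (-65504) so they contribute nothing after softmax.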


# number of decoder layers (gemma-2b has 18)
layer_num = 18
eos_token_id = 1  # gemma's <eos> token id

pt_model_path = r"E:\test_models\llama\gemma-2b-it"
onnx_model_path = "llm_onnx.onnx"

prompt = "Write me a poem about Machine Learning."
tokenizer = AutoTokenizer.from_pretrained(pt_model_path, trust_remote_code=True)
input_ids = tokenizer(prompt)['input_ids']

print(input_ids)

input_ids = np.array(input_ids).reshape([1, -1]).astype("int64")

N = input_ids.shape[1]
sumN = N
lastSum = sumN - N
print("N:", N, sumN, lastSum)

position_ids = np.arange(sumN).reshape([1, -1]).astype("int64")

input_ids = input_ids.astype("int64")
position_ids = position_ids.astype("int64")

glm_model = OnnxRuntimeModel(onnx_model_path)

max_seq = 32

glm_model_inputs = {}

gen_tokens = []

for i in range(max_seq):
    print("input_ids:", input_ids)
    print("position_ids:", position_ids)

    attention_mask = get_atten_mask(N, sumN, padded_len=sumN).astype("float16")
    print("attention_mask:", attention_mask)
    attention_mask = attention_mask.reshape([1, 1, N, sumN])

    glm_model_inputs["input_ids"] = input_ids
    glm_model_inputs["attention_mask"] = attention_mask
    glm_model_inputs["position_ids"] = position_ids

    if i == 0:
        glm_model_inputs = prepare_kv_cache_round0(glm_model_inputs, layer_num, lastSum)

    glm_model_outputs = glm_model(**glm_model_inputs)
    lm_logits = glm_model_outputs[0]
    print("lm_logits:", lm_logits)

    next_token = sample_topk(lm_logits, topk=1)
    gen_tokens.append(next_token)
    print("next_token:", next_token)

    if next_token == eos_token_id:
        break

    input_ids = np.array([next_token]).astype("int64").reshape([-1, 1])
    position_ids = np.array([sumN]).astype("int64").reshape([-1, 1])

    N = 1
    sumN += 1
    prepare_kv_cache_from_outputs(glm_model_inputs, glm_model_outputs, layer_num)

gen_text = tokenizer.decode(gen_tokens)
print("Q:", prompt)
print("A:", gen_text)
