Test model:
https://huggingface.co/RWKV/rwkv-5-world-3b
Before export, make the following modification to modeling_rwkv5.py:
# out = out.reshape(B * T, H * S)
out = out.reshape(B * T, H * S, 1) # <<--- modified
out = F.group_norm(out, num_groups=H, weight=lxw, bias=lxb).reshape(B, T, H * S)
This is needed because PyTorch's ONNX export currently has a bug: group_norm with a 2D input cannot be exported, so the input is reshaped to 3D first (a standalone sketch of the workaround follows).
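A minimal standalone repro of the workaround (illustrative module and file names, not the RWKV code), which can be used to check that group_norm exports once the input is made 3D:

import torch
import torch.nn.functional as F

class GroupNormOn2D(torch.nn.Module):
    # group_norm over a (N, C) tensor: reshape to (N, C, 1) for the exporter, then reshape back
    def __init__(self, num_groups, num_channels):
        super().__init__()
        self.num_groups = num_groups
        self.weight = torch.nn.Parameter(torch.ones(num_channels))
        self.bias = torch.nn.Parameter(torch.zeros(num_channels))

    def forward(self, x):  # x: (N, C)
        out = F.group_norm(x.reshape(x.shape[0], x.shape[1], 1),
                           num_groups=self.num_groups, weight=self.weight, bias=self.bias)
        return out.reshape(x.shape[0], x.shape[1])

torch.onnx.export(GroupNormOn2D(8, 512), torch.randn(4, 512), "gn_3d_workaround.onnx", opset_version=15)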
Notes:
rwkv_linear_attention_v5_cpu uses for t in range(T): to split the computation across time steps. Since tracing unrolls this Python loop for the fixed T of the dummy input, the ONNX graph exported for the first prompt pass (T > 1) is structurally different from the one for the subsequent decoding steps (T = 1). This part needs to be reworked before a single ONNX model can serve both the prompt and decoding phases.
Branching logic such as if hidden.size(1) == 1 can cause the same problem, because only the branch taken during tracing ends up in the exported graph (a toy illustration follows).
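A toy illustration of both issues (hypothetical module, not the RWKV code): the Python loop and the size check are resolved at trace time, so the exported graph is frozen to whatever T the dummy input had:

import torch

class ToyRecurrent(torch.nn.Module):
    def forward(self, hidden):  # hidden: (B, T, C)
        B, T, C = hidden.shape
        if hidden.size(1) == 1:      # branch chosen once, at trace time
            return hidden * 2.0
        acc = torch.zeros(B, C)
        outs = []
        for t in range(T):           # unrolled T times in the traced graph
            acc = acc + hidden[:, t]
            outs.append(acc)
        return torch.stack(outs, dim=1)

# exporting with T=4 (prompt) vs T=1 (decoding) yields structurally different graphs,
# even though the T axis is declared dynamic
torch.onnx.export(ToyRecurrent(), torch.randn(1, 4, 8), "toy_prompt.onnx",
                  input_names=["hidden"], dynamic_axes={"hidden": {1: "T"}})
torch.onnx.export(ToyRecurrent(), torch.randn(1, 1, 8), "toy_decode.onnx", input_names=["hidden"])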
In addition, this RWKV model can be optimized further for efficient inference. For example, the per-layer state is currently updated with the layer index as the innermost (last) dimension:
state[1][:, :, :, :, self.layer_id] = layer_state
This performs significantly worse than placing layer_id as the outermost dimension:
state[1][self.layer_id, :, :, :, :] = layer_state
One can even go further and, as in transformer-architecture models, store each layer's layer_state as a separate entry of a list. This increases the number of model inputs and outputs, but avoids the expensive ScatterND operators entirely (see the sketch below).
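A rough sketch of the three state layouts (shapes are illustrative, not taken from modeling_rwkv5.py):

import torch

B, H, S, L = 1, 40, 64, 32   # illustrative batch, head count, head size, layer count
layer_id = 3
layer_state = torch.randn(B, H, S, S)

# layout 1 (current): layer index innermost -> strided write, typically lowered to ScatterND in the graph
state_layer_last = torch.zeros(B, H, S, S, L)
state_layer_last[:, :, :, :, layer_id] = layer_state

# layout 2: layer index outermost -> a contiguous slice assignment, significantly cheaper
state_layer_first = torch.zeros(L, B, H, S, S)
state_layer_first[layer_id] = layer_state

# layout 3: one tensor per layer, like per-layer KV caches in transformer exports;
# more graph inputs/outputs, but no scatter ops at all
state_per_layer = [torch.zeros(B, H, S, S) for _ in range(L)]
state_per_layer[layer_id] = layer_state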
Reference export code (exporting with device=cpu can also be tried):
import os
import argparse
import torch
from torch import nn
from transformers import AutoModelForCausalLM, AutoTokenizer

class LLMForCausalLMWrapper(nn.Module):
    """Wraps the HF RWKV model so the exported graph takes (input_ids, state) and returns (logits, state)."""

    def __init__(self, model, config, args):
        super().__init__()
        self.model = model
        self.config = config
        self.args = args

    def forward(
        self,
        input_ids,
        state,
    ):
        outputs = self.model(
            input_ids=input_ids,
            state=state,
            use_cache=True,
        )
        logits = outputs.logits
        state_out = outputs.state
        return logits, state_out

def export_llm_to_single_onnx(model, config, dtype, args, model_name):
    model_wrapper = LLMForCausalLMWrapper(model, config, args)

    if args.out_dir:
        os.makedirs(args.out_dir, exist_ok=True)
    onnx_file_name = os.path.join(args.out_dir, f"{model_name}.onnx")

    hidden_size = config.hidden_size
    layer_num = config.num_hidden_layers
    head_num = config.hidden_size // config.num_attention_heads
    head_hidden_size = config.hidden_size // head_num

    batch = 1
    N = 4  # dummy prompt length used for tracing
    input_ids_shape = [batch, N]
    input_ids = torch.ones(input_ids_shape, dtype=torch.int64).to(args.device)

    # the sequence length is exported as a dynamic axis; batch only with --dyn_batch
    dynamic_axes = {
        'input_ids': {1: 'N', },
    }
    if args.dyn_batch:
        dynamic_axes['input_ids'][0] = "batch"

    # the three RWKV-5 state tensors; the layer index is the last dimension
    state_0 = torch.randn([batch, hidden_size, layer_num], dtype=dtype).to(args.device)
    state_1 = torch.randn([batch, head_num, head_hidden_size, head_hidden_size, layer_num], dtype=dtype).to(args.device)
    state_2 = torch.randn([batch, hidden_size, layer_num], dtype=dtype).to(args.device)
    state = [state_0, state_1, state_2]

    in_names = ["input_ids", "state_0_in", "state_1_in", "state_2_in"]
    out_names = ["lm_logits", "state_0_out", "state_1_out", "state_2_out"]

    input_datas = (input_ids, state)

    torch.onnx.export(
        model_wrapper,
        input_datas,
        onnx_file_name,
        opset_version=args.opset,
        do_constant_folding=True,
        input_names=in_names,
        output_names=out_names,
        dynamic_axes=dynamic_axes,
    )

def export_rwkv(args):
    device = args.device
    dtype_map = {
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
    }
    dtype = dtype_map[args.dtype]

    print(f"begin load model from {args.model_path}")
    model = AutoModelForCausalLM.from_pretrained(
        args.model_path, device_map=device, torch_dtype=dtype, trust_remote_code=True).eval()
    model.rwkv.blocks = model.rwkv.blocks[:1]  # debug only: export just the first block; remove this line for the full model
    print(f"finish load model from {args.model_path}")

    config = model.config
    print("config:", config)

    print("begin export llm")
    export_llm_to_single_onnx(model, config, dtype, args, "llm_onnx")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description='export llm',
    )
    parser.add_argument('-m', '--model_path', required=True, type=str)
    parser.add_argument('-o', '--out_dir', required=False, type=str, default="")
    parser.add_argument('--opset', required=False, type=int, default=15)
    parser.add_argument('-d', '--device', required=False, type=str, choices=["cpu", "cuda"], default="cuda")
    parser.add_argument('-p', '--dtype', required=False, type=str,
                        choices=["float32", "float16", "bfloat16"], default="float16")
    parser.add_argument('--add_topk_warper', required=False, type=int, default=0)
    parser.add_argument('--topk', required=False, type=int, default=4)
    parser.add_argument('--dyn_batch', action='store_true')

    args = parser.parse_args()
    export_rwkv(args)
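Example invocation (assuming the script above is saved as export_rwkv_onnx.py; the model path and output directory are placeholders):

python export_rwkv_onnx.py -m ./rwkv-5-world-3b -o ./rwkv_onnx -d cpu -p float32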
For exporting other models and for simplifying large ONNX models with onnxsim, see:
https://github.com/luchangli03/export_llama_to_onnx (export llama to onnx)
https://github.com/luchangli03/onnxsim_large_model (simplify >2GB large onnx model)