Liger Kernel: Using Triton to Optimize transformers Training

Liger (LinkedIn GPU Efficient Runtime) Kernel is a collection of Triton kernels designed specifically for LLM training. We have implemented Hugging Face Compatible RMSNorm, RoPE, SwiGLU, CrossEntropy, FusedLinearCrossEntropy, and more to come. It can effectively increase multi-GPU training throughput by 20% and reduce memory usage by 60%. The kernel works out of the box with flash attention, PyTorch FSDP, and Microsoft DeepSpeed. We welcome contributions from the community to gather the best kernels for LLM training.

Liger Kernel speeds up training and lowers peak training memory. It sits on top of the transformers package: from the liger_kernel.transformers package you import an apply_liger_kernel_to_xxx() function, which replaces the corresponding components of the original transformers implementation with versions rewritten and accelerated in Triton.
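A minimal usage sketch (assuming the liger-kernel package is installed; the checkpoint name is just an example): call the patch function before loading the model, and everything else stays plain transformers:

from transformers import AutoModelForCausalLM
from liger_kernel.transformers import apply_liger_kernel_to_llama

# Monkey-patch the Llama components in transformers with Liger's Triton kernels.
apply_liger_kernel_to_llama()

# Any Llama model loaded afterwards uses the patched RMSNorm / SwiGLU / RoPE / loss.
model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B")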

For example, here is the body of apply_liger_kernel_to_llama() (excerpted from Liger Kernel, imports omitted): it replaces the RoPE, RMSNorm, SwiGLU, and cross-entropy modules of the original modeling_xxx.

def apply_liger_kernel_to_llama(
    rope: bool = True,
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = True,
    rms_norm: bool = True,
    swiglu: bool = True,
    model: PreTrainedModel = None,
) -> None:
    """
    Apply Liger kernels to replace original implementation in HuggingFace Llama models (2 and 3)

    Args:
        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
        fused_linear_cross_entropy (bool):
            Whether to apply Liger's fused linear cross entropy loss. Default is True.
            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
        loaded. Default is None.
    """

    assert not (
        cross_entropy and fused_linear_cross_entropy
    ), "cross_entropy and fused_linear_cross_entropy cannot both be True."

    from transformers.models.llama import modeling_llama

    if rope:
        modeling_llama.apply_rotary_pos_emb = liger_rotary_pos_emb
    if rms_norm:
        modeling_llama.LlamaRMSNorm = LigerRMSNorm
    if swiglu:
        modeling_llama.LlamaMLP = LigerSwiGLUMLP
    if cross_entropy:
        modeling_llama.CrossEntropyLoss = LigerCrossEntropyLoss
    if fused_linear_cross_entropy:
        modeling_llama.LlamaForCausalLM.forward = llama_lce_forward

    if model is not None:
        # The model instance already exists, so we need to additionally patch the
        # instance variables that reference already-instantiated modules (e.g. LlamaRMSNorm or LlamaMLP)
        config: PretrainedConfig = model.config

        if hasattr(model, "model"):
            # The case for LlamaForCausalLM or LlamaForSequenceClassification, for example
            base_model = model.model
        elif hasattr(model, "transformer"):
            # LlamaForQuestionAnswering uses "transformer" instead of "model"
            base_model = model.transformer
        else:
            # Direct LlamaModel
            base_model = model

        torch_dtype = config.torch_dtype
        if rms_norm:
            base_model.norm = LigerRMSNorm(
                config.hidden_size, eps=config.rms_norm_eps
            ).to(torch_dtype)

        for decoder_layer in base_model.layers:
            if swiglu:
                decoder_layer.mlp = LigerSwiGLUMLP(config).to(torch_dtype)
            if rms_norm:
                decoder_layer.input_layernorm = LigerRMSNorm(
                    config.hidden_size, eps=config.rms_norm_eps
                ).to(torch_dtype)
                decoder_layer.post_attention_layernorm = LigerRMSNorm(
                    config.hidden_size, eps=config.rms_norm_eps
                ).to(torch_dtype)
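The model argument covers the opposite order: when the model has already been instantiated, class-level patching alone would miss the existing submodules, so the code above additionally swaps them in place. A minimal sketch of that path (same example checkpoint as before):

import torch
from transformers import AutoModelForCausalLM
from liger_kernel.transformers import apply_liger_kernel_to_llama

# The model is loaded first, so its LlamaRMSNorm / LlamaMLP instances already exist...
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B", torch_dtype=torch.bfloat16
)
# ...and passing it in applies the class-level patches plus the in-place
# replacement of the norm / mlp modules shown above.
apply_liger_kernel_to_llama(model=model)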

So how exactly do the replacement modules speed things up? Let's take the MLP as an example: Liger's MLP implementation changes nothing in __init__; in the forward pass it applies a custom op of its own, which optimizes the original SiLU-and-multiply of the gate_proj and up_proj outputs.
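Concretely, LigerSwiGLUMLP looks roughly like the following sketch (paraphrased from the Liger Kernel source; the import path and minor details are assumptions and may differ between versions):

import torch.nn as nn
from liger_kernel.ops.swiglu import LigerSiLUMulFunction  # assumed import path

class LigerSwiGLUMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        # Same three projections as the original LlamaMLP -- __init__ is unchanged.
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)

    def forward(self, x):
        # The only change: silu(gate_proj(x)) * up_proj(x) goes through one fused
        # Triton op instead of a separate activation and element-wise multiply.
        return self.down_proj(
            LigerSiLUMulFunction.apply(self.gate_proj(x), self.up_proj(x))
        )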

Inside the custom op, Triton is imported and the forward and backward passes are rewritten by hand:

class LigerSiLUMulFunction(torch.autograd.Function):
    @staticmethod
    @ensure_contiguous
    def forward(ctx, a, b):
        # swiglu_forward launches the Triton kernel that computes c = silu(a) * b
        # in a single fused pass over a and b.
        a, b, c = swiglu_forward(a, b)
        # Keep a and b around for the backward pass.
        ctx.save_for_backward(a, b)
        return c

    @staticmethod
    @ensure_contiguous
    def backward(ctx, dc):
        a, b = ctx.saved_tensors
        # swiglu_backward launches the Triton kernel that turns the upstream
        # gradient dc into the gradients w.r.t. a and b.
        a, b = swiglu_backward(a, b, dc)
        return a, b
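Here swiglu_forward and swiglu_backward are the wrappers that actually launch the Triton kernels. As a rough illustration of the idea only (this is not the Liger implementation; the kernel and launcher below are hypothetical), a fused silu(a) * b forward can be written in Triton like this:

import torch
import triton
import triton.language as tl

@triton.jit
def _silu_mul_fwd(a_ptr, b_ptr, c_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    # One read of a and b, one write of c: silu(a) * b is computed in a single pass,
    # so no intermediate silu(a) tensor is materialized in global memory.
    a = tl.load(a_ptr + offsets, mask=mask).to(tl.float32)
    b = tl.load(b_ptr + offsets, mask=mask).to(tl.float32)
    c = a * tl.sigmoid(a) * b
    tl.store(c_ptr + offsets, c.to(c_ptr.dtype.element_ty), mask=mask)

def silu_mul_forward(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Host-side launcher: one program instance per BLOCK_SIZE elements.
    c = torch.empty_like(a)
    n_elements = a.numel()
    grid = (triton.cdiv(n_elements, 1024),)
    _silu_mul_fwd[grid](a, b, c, n_elements, BLOCK_SIZE=1024)
    return c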