AWQ Quantization and an AutoAWQ Code Walkthrough

AWQ quantization comes from a 2023 paper on LLM quantization by Song Han's group at MIT: AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration.


Before getting into AWQ itself, here is a brief overview of model quantization.

1. Why quantize a model, and what does quantization buy you?

Models are quantized because we normally train and run them in fp16 (floating point 16) or bf16 (Brain Floating Point). fp16 has 1 sign bit, 5 exponent bits and 10 mantissa bits; its representable range is finfo(resolution=0.001, min=-65504, max=65504, eps=0.000976562, smallest_normal=6.10352e-05, tiny=6.10352e-05, dtype=float16).

bf16 has 1 sign bit, 8 exponent bits and 7 mantissa bits; its representable range is finfo(resolution=0.01, min=-3.38953e+38, max=3.38953e+38, eps=0.0078125, smallest_normal=1.17549e-38, tiny=1.17549e-38, dtype=bfloat16).
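The ranges quoted above are simply the repr of torch.finfo, which you can print yourself:

import torch

# Numeric limits of the two 16-bit float formats
print(torch.finfo(torch.float16))   # resolution=0.001, min=-65504, max=65504, ...
print(torch.finfo(torch.bfloat16))  # resolution=0.01, min/max ≈ ±3.39e+38, ...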

As the finfo output shows, fp16 and bf16 cover a very wide range, but this comes with two problems. (1) Memory (VRAM) overhead is large. Taking the Qwen1.5-32B model the author actually uses as an example, a 32B 16-bit floating-point model needs about 64 GB of VRAM just to load; adding the calibration data and the KV cache, it takes roughly 4×RTX 4090 (96 GB) to run. That is a substantial cost; a 4×4090 server runs around 100,000 RMB. (2) Runtime overhead is also large. Matrix multiplication with 16-bit floating-point weights is slow, and SqueezeLLM even argues that the runtime bottleneck is loading the model weights, so lowering the weight bit-width can significantly speed up inference. (In my own tests with AWQ, however, the AWQ model actually ran slower than the 16-bit floating-point model, because the AWQ weights must be dequantized before the GEMM or GEMV.)

Once these drawbacks of 16-bit floating-point models were recognized, people started converting the 16-bit floats into lower-bit integer types, e.g. int8 (LLM.int8, SmoothQuant), int4 (GPTQ, AWQ), 3-bit (SqueezeLLM), and even 1–2 bit quantization (AQLM). Quantization mitigates both problems above: when AWQ quantizes a model from 16-bit floating point (16 bits) to int4 (4 bits), the model shrinks from 64 GB to 16 GB, a quarter of its original size, saving a large amount of compute resources.

2. How do you quantize a model?

For a general introduction, see the CSDN blog post 模型量化详解.

In short, a 16-bit floating-point weight can be reconstructed as a low-bit integer multiplied by a scale.
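As a toy illustration of that idea (my own sketch, not AutoAWQ code; AWQ itself uses grouped asymmetric quantization via pseudo_quantize_tensor, shown later), symmetric round-to-nearest quantization looks like this:

import torch

def rtn_quantize(w: torch.Tensor, n_bits: int = 4):
    # symmetric scale: map the largest magnitude onto the int4 grid
    delta = w.abs().max() / (2 ** (n_bits - 1) - 1)
    # quantize: low-bit integer representation
    w_int = torch.round(w / delta).clamp(-(2 ** (n_bits - 1)), 2 ** (n_bits - 1) - 1)
    # dequantize: the low-bit integer times the scale approximates the fp16 weight
    return w_int, delta, w_int * delta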


2.1 The main idea of AWQ

Key insight 1: weights are not equally important; only a small fraction of salient weights has a large impact on the inference result

The authors point out that weights are not equally important: only about 0.1%–1% of them, the salient weights, have a significant impact on the model's output accuracy. If we could keep just this 0.1%–1% of the weights in their original precision (FP16) and quantize the rest to low bit-width, we could keep accuracy almost unchanged while drastically reducing memory footprint and improving inference speed. This raises the question of how to identify the salient weights; three common approaches are:

  • Random selection: pick 0.1%–1% of the weights at random as salient and hope for the best; obviously not very scientific.

  • Selection by weight distribution: sort the elements of a weight matrix (e.g. W_q, W_k, W_v in self-attention) by absolute value, treat larger magnitudes as more salient, and take the top 0.1%–1%.

  • Selection by activation distribution: the activations are the inputs that are multiplied (matmul) with the weight matrix; channels whose activations have large magnitude are treated as salient.

The authors tested all three approaches (Table 1 in the paper): random selection performs about as well as plain RTN (round-to-nearest) quantization, and selection based on weight magnitude is no better than random, while selecting weights by activation magnitude nearly matches FP16 accuracy.

To keep the method simple to implement, the authors select salient weights at the channel level rather than the element level, i.e. one row of the weight matrix is treated as a unit. Concretely, they take the mean absolute value of each column of the activations, treat the channels corresponding to the largest means as salient, keep those in FP16, and quantize the remaining channels to low bit-width (see the sketch below).
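A hypothetical sketch of this selection step (my illustration, not AutoAWQ code): average the absolute activations per input channel and take the top 1% as salient channels.

import torch

def salient_channels(x: torch.Tensor, top_ratio: float = 0.01) -> torch.Tensor:
    # x: calibration activations flattened to [n_tokens, in_features]
    channel_mean = x.abs().mean(dim=0)             # mean |activation| per column (channel)
    k = max(1, int(top_ratio * channel_mean.numel()))
    return torch.topk(channel_mean, k).indices     # indices of the salient channels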

But this raises another problem: if some elements of a weight matrix are stored in FP16 and others in INT4, both storage and the memory accesses during compute become awkward. So the authors came up with a workaround: scaling.

Key insight 2: scaling up the salient weights before quantization reduces their quantization error

The quantization of a weight w can be written as Q(w) = Δ · Round(w / Δ), with Δ = max(|w|) / 2^(N−1), where N is the bit-width after quantization and Δ is the quantization scale. w′ = Round(w / Δ) is the quantization step and Δ · w′ is the dequantization step. The original w, Δ and the input x are all FP16 and introduce no precision loss; the entire loss comes from the Round function, whose error is approximately uniform on [0, 0.5] with expectation 0.25, written RoundErr(·) ≈ 0.25.

Now consider a single element w of the weight matrix and introduce a scaling factor s > 1. Multiply w by s before quantizing, i.e. w′ = Round(w·s / Δ′), and correspondingly dequantize as Δ′ · w′ · x / s. Computationally this is "equivalent" to the original, as in Equation 2: Q(w·s) · (x / s) = Δ′ · Round(w·s / Δ′) · x · (1/s).

Equation 1 and Equation 2 compute the same result, yet their precision losses differ: the error of the original quantization is Err(Q(w)·x) = Δ · RoundErr(w/Δ) · x, while the scaled version gives Err(Q(w·s)·(x/s)) = Δ′ · RoundErr(w·s/Δ′) · x · (1/s). Since scaling a single channel rarely changes the group maximum, Δ′ ≈ Δ, so the error of the scaled weight shrinks roughly by the factor Δ′/(Δ·s) ≈ 1/s.

The authors therefore changed tack: to stay hardware-friendly, all weights are quantized to low bit-width, but during quantization the salient weights are multiplied by a larger s (lowering their quantization error) while the non-salient weights get a smaller s (receiving less attention). This is the scaling trick mentioned in the previous section.

Algorithm: computing the scaling factors automatically

Following the authors' reasoning, the larger the activation, the more salient the corresponding channel, and the larger the scaling factor it should get to reduce its quantization error. They therefore take the per-channel average activation magnitude s_x (the mean absolute value of each column of the input matrix) as the base scaling factor, and introduce a variable α to balance salient and non-salient channels: s = s_x^α. The problem then becomes finding the α that minimizes L(s) = ||Q(W · diag(s))(diag(s)^(-1) · X) − W·X||, which is done by grid search: in the source code, 20 evenly spaced values in [0, 1) are tried as α and the one with the smallest L(s) is kept. SmoothQuant follows the same idea as AWQ, but computes s as s_j = max(|X_j|)^α / max(|W_j|)^(1−α).

2.2 Code

AutoAWQ's quantization process starts from AwqQuantizer.init_quant(); self.awq_model is the Qwen2AWQForCausalLM class (taking Qwen1.5 as the example) and self.model is the loaded Qwen model.

 def init_quant(self, n_samples=128, seqlen=512):
        modules = self.awq_model.get_model_layers(self.model) # return model.model.layers
        samples = get_calib_dataset(
            data=self.calib_data,
            tokenizer=self.tokenizer,
            n_samples=n_samples,
            block_size=seqlen,
            split=self.split,
            text_column=self.text_column,
        )
        samples = torch.cat(samples, dim=0)
-----------------------------------------------------------
        class Catcher(nn.Module):
            def __init__(self, module):
                super().__init__()
                self.module = module

            def forward(self, *args, **kwargs):
                # assume first input to forward is hidden states
                if len(args) > 0:
                    hidden_states = args[0]
                    del args
                else:
                    first_key = list(kwargs.keys())[0]
                    hidden_states = kwargs.pop(first_key)

                inps.append(hidden_states)
                layer_kwargs.update(kwargs)
                raise ValueError  # early exit to break later inference
-----------------------------------------------------------       
        return modules, layer_kwargs, inps

[STEP 1]: Get layer, extract linear modules, extract input features

The linear layers inside the module are collected into a dictionary, and _get_input_feat captures and stores the input of each layer.

# [STEP 1]: Get layer, extract linear modules, extract input features
    named_linears = get_named_linears(self.modules[i])
    # named_linears is the dictionary of named linear layers in the module, e.g. :
    """
    {'self_attn.q_proj': Linear(in_features=1024, out_features=1024, bias=True),
    'self_attn.k_proj': Linear(in_features=1024, out_features=1024, bias=True),
    'self_attn.v_proj': Linear(in_features=1024, out_features=1024, bias=True),
    'self_attn.o_proj': Linear(in_features=1024, out_features=1024, bias=False),
    'mlp.gate_proj': Linear(in_features=1024, out_features=2816, bias=False),
    'mlp.up_proj': Linear(in_features=1024, out_features=2816, bias=False),
    'mlp.down_proj': Linear(in_features=2816, out_features=1024, bias=False)}
    """

    # Filter out the linear layers we don't want to exclude
    named_linears = exclude_layers_to_not_quantize(
        named_linears, self.modules_to_not_convert
    )

    input_feat = self._get_input_feat(self.modules[i], named_linears)
    clear_memory()

[STEP 2]: Compute and apply scale list

module_config: List[Dict] = self.awq_model.get_layers_for_scaling(
    self.modules[i], input_feat, self.module_kwargs
)
# The call above extracts the block's layers into a list of dicts:
# prev_op is the preceding op, layers are the linear layers of the current group,
# inp is the captured input features, and module2inspect is the module containing all of them.

scales_list = [
    self._search_best_scale(self.modules[i], **layer)
    for layer in module_config
]
apply_scale(self.modules[i], scales_list, input_feat_dict=input_feat)

Step 2 computes the scaling for each layer. module_config is a list of dicts containing the current linear layers, the preceding op, and the input features. To find the best scale, the weights are first grouped and normalised, then averaged along the channel dimension to get the weight mean; likewise the input x is averaged per channel. The FP16 output of the module is also computed as the reference for picking the best scale.

@torch.no_grad()
    def _search_best_scale(
        self,
        module,
        prev_op,
        layers: List[nn.Linear],
        inp: torch.Tensor,
        module2inspect=None,
        kwargs={},
    ):
        if module2inspect is None:
            assert len(layers) == 1
            module2inspect = layers[0]

        if "use_cache" in kwargs:
            kwargs.pop("use_cache")

        # Put x on the right device
        inp = inp.to(next(module2inspect.parameters()).device)

        # [STEP 1]: Compute per-channel mean of normalised weights
        # All layer weights are concatted together
        weight = torch.cat([_m.weight for _m in layers], dim=0)
        org_shape = weight.shape
        # The weights are reshaped to be organised by quantization group
        weight = weight.view(-1, self.group_size)
        # Calculates the relative magnitude of the weights within each of the quantization groups, 
        # and rescales each group individually so that each group has weights on a 0-1 scale.
        w_scale = weight.abs() / (weight.abs().amax(dim=1, keepdim=True) + 1e-6)
        # Resizes the rescaled weight matrix back up to its original dimensions
        w_scale = w_scale.view(org_shape)
        # Gets the average rescaled magnitude for each output channel
        w_mean = w_scale.mean(0)
        clear_memory(weight)

        # [STEP 2]: Compute per-channel mean of the input activation
        x_mean = inp.abs().view(-1, inp.shape[-1]).mean(0)

        # [STEP 3]: Compute output of module
        with torch.no_grad():
            module_kwargs = self._sanitize_kwargs(kwargs, module2inspect)

            fp16_output = module2inspect(inp, **module_kwargs)
            if isinstance(fp16_output, tuple):
                fp16_output = fp16_output[0]

        # [STEP 4]: Compute loss
        best_scales = self._compute_best_scale(
            inp, w_mean, x_mean, module2inspect, layers, fp16_output, module_kwargs
        )

        return (
            get_op_name(module, prev_op),
            tuple([get_op_name(module, m) for m in layers]),
            best_scales,
        )

_compute_best_scale performs the grid search: for the scale formula s = s_x^α (Equation (4) in the paper), it directly uses x_mean raised to the power α (called ratio in the code) as s, with α acting as the balancing factor. The grid search looks for the α that minimises the difference (L2 loss) between the quantized module's output and the FP16 module's output.

def _compute_best_scale(
        self,
        x,
        w_mean,
        x_mean,
        module2inspect,
        linears2scale: List[nn.Linear],
        fp16_output,
        kwargs={},
    ):
        """
        Compute loss and select best scales

        L(s) = || Q(W * s) (s^-1 * X) - W * X ||
        Q: weight quantization function | pseudo_quantize_tensor(W * s)
        X: inputs from calib dataset    | X
        W: original weights in FP16     | layer
        s: per channel scaling factor   | s^-1 * X
        """
        n_grid = 20
        history = []
        best_ratio = -1
        best_scales = None
        best_error = float("inf")

        org_sd = {k: v.cpu() for k, v in module2inspect.state_dict().items()}

        device = x.device
        x_mean = x_mean.view(-1).to(device)
        w_mean = w_mean.view(-1).to(device)

        for ratio in range(n_grid):
            # create new scales
            ratio = ratio / n_grid

            # NOTE: s^-1 * x is fused here, according to paper
            if self.duo_scaling:
                scales = (x_mean.pow(ratio) / (w_mean.pow(1 - ratio) + 1e-4)).clamp(min=1e-4)
            else:
                scales = x_mean.pow(ratio).clamp(min=1e-4).view(-1)
            scales = scales / (scales.max() * scales.min()).sqrt()
            scales_view = scales.view(1, -1).to(device)

            # Q(W * s)
            for fc in linears2scale:
                fc.weight.mul_(scales_view)
                fc.weight.data = (
                    self.pseudo_quantize_tensor(fc.weight.data)[0] / scales_view
                )

            # W * X
            int_w_output = module2inspect(x, **kwargs)
            if isinstance(int_w_output, tuple):
                int_w_output = int_w_output[0]

            # compute mean squared error (L2 norm)
            loss = (
                (fp16_output - int_w_output).float().pow(2).mean().item()
            )  # NOTE: float prevents overflow

            history.append(loss)
            if loss < best_error:
                best_error = loss
                best_ratio = ratio
                best_scales = scales.clone()
            module2inspect.load_state_dict(org_sd)

        if best_ratio == -1:
            logging.debug(history)
            raise Exception

        assert torch.isnan(best_scales).sum() == 0, best_scales

        return best_scales.detach().cpu()

apply_scale then applies these scales to the weights of each layer, for example:
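Roughly, based on my reading of AutoAWQ's scale_ln_fcs helper (a sketch, not the verbatim source): the previous op's weights are divided by the scales and the following linear layers' input channels are multiplied by them, so the overall computation stays equivalent while the target weights effectively become W·s.

import torch
import torch.nn as nn

@torch.no_grad()
def scale_ln_fcs_sketch(ln: nn.LayerNorm, fcs: list, scales: torch.Tensor):
    # fold 1/s into the previous op (here a LayerNorm over the same hidden dim) ...
    ln.weight.div_(scales)
    if getattr(ln, "bias", None) is not None:
        ln.bias.div_(scales)
    # ... and fold s into the input channels of the following linear layers,
    # so fc(ln(x)) is unchanged while fc.weight effectively becomes W * s
    for fc in fcs:
        fc.weight.mul_(scales.view(1, -1))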

Test: changing the grid size

The grid size controls how finely the balancing factor is sampled. I ran the search with grid = 10, 20 (the AWQ default), 40 and 100; judging by the results, changing the grid size did not bring any accuracy improvement.

[STEP 3]: Compute and apply clipping list

Clipping also uses a grid search, this time to find the most suitable maximum value, to which the weights are then clipped.

for i_b in range(org_w_shape[0] // oc_batch_size):
    w = w_all[i_b * oc_batch_size : (i_b + 1) * oc_batch_size]

    org_max_val = w.abs().amax(dim=-1, keepdim=True)  # co, 1, n_group, 1

    best_max_val = org_max_val.clone()
    min_errs = torch.ones_like(org_max_val) * 1e9
    input_feat = input_feat.to(w.device)
    org_out = (input_feat * w).sum(dim=-1)  # co, n_token, n_group

    for i_s in range(int(max_shrink * n_grid)):
        max_val = org_max_val * (1 - i_s / n_grid)
        min_val = -max_val
        cur_w = torch.clamp(w, min_val, max_val)
        q_w = self.pseudo_quantize_tensor(cur_w)[0]
        cur_out = (input_feat * q_w).sum(dim=-1)

        # co, 1, n_group, 1
        err = (cur_out - org_out).pow(2).mean(dim=1).view(min_errs.shape)
        del cur_w
        del cur_out
        cur_best_idx = err < min_errs
        min_errs[cur_best_idx] = err[cur_best_idx]
        best_max_val[cur_best_idx] = max_val[cur_best_idx]
    best_max_val_all.append(best_max_val)

best_max_val = torch.cat(best_max_val_all, dim=0)

@torch.no_grad()
def apply_clip(module, clip_list: Tuple[str, torch.Tensor]):
    for name, max_val in clip_list:
        layer: nn.Linear = get_op_by_name(module, name)
        layer.to(get_best_device())
        max_val = max_val.to(layer.weight.device)
        org_shape = layer.weight.shape
        layer.weight.data = layer.weight.data.reshape(*max_val.shape[:2], -1)
        layer.weight.data = torch.clamp(layer.weight.data, -max_val, max_val)
        layer.weight.data = layer.weight.data.reshape(org_shape)
        layer.cpu()

[STEP 4]: Quantize weights

    def _apply_quant(self, module, named_linears: Dict[str, nn.Linear]):
        for name, linear_layer in named_linears.items():
            # NOTE: small regression in perplexity if linear layer uses .cpu().float()
            linear_layer = linear_layer.to(get_best_device()).half()

            linear_layer.weight.data, scales, zeros = self.pseudo_quantize_tensor(
                linear_layer.weight.data
            )

            if self.version == "gemm":
                scales = scales.t().contiguous()
                if zeros is not None:
                    zeros = zeros.t().contiguous()
                q_linear_module = WQLinear_GEMM

            elif self.version == "gemv":
                q_linear_module = WQLinear_GEMV

            elif self.version == "marlin":
                q_linear_module = WQLinear_Marlin
            
            elif self.version == "gemv_fast":
                q_linear_module = WQLinear_GEMVFast

            else:
                raise ValueError(f"Unknown version {self.version}")

            q_linear = q_linear_module.from_linear(
                linear=linear_layer,
                w_bit=self.w_bit,
                group_size=self.group_size,
                init_only=False,
                scales=scales,
                zeros=zeros,
            )

            linear_layer.cpu()
            q_linear.to(next(module.parameters()).device)
            set_op_by_name(module, name, q_linear)
            clear_memory()

In WQLinear_GEMM, the weights are quantized group-wise by the code below; each group shares one scale.

pack_num = 32 // awq_linear.w_bit
intweight = []
# scale_zeros (= zeros * scales) is computed earlier in from_linear
for idx in range(awq_linear.in_features):
    intweight.append(
        torch.round(
            (linear.weight.data[:, idx] + scale_zeros[idx // group_size])
            / awq_linear.scales[idx // group_size]
        ).to(torch.int)[:, None]
    )
intweight = torch.cat(intweight, dim=1)
intweight = intweight.t().contiguous()
intweight = intweight.to(dtype=torch.int32)

After quantization, an int32 tensor qweight is created to store the quantized weights. Since the quantized weights need to be stored as int4, a single int32 element can hold 8 of them, and bit shifts are used to pack 8 int4 values into one int32.

qweight = torch.zeros(
    (intweight.shape[0], intweight.shape[1] // 32 * awq_linear.w_bit),
    dtype=torch.int32,
    device=intweight.device,
)

for col in range(intweight.shape[1] // pack_num):
    if awq_linear.w_bit == 4:
        order_map = [0, 2, 4, 6, 1, 3, 5, 7]
    else:
        raise NotImplementedError("Only 4-bit are supported for now.")
    for i in range(pack_num):
        qweight_col = intweight[:, col * pack_num + order_map[i]]
        qweight[:, col] |= qweight_col << (i * awq_linear.w_bit)
awq_linear.qweight = qweight
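
To make the bit layout concrete, here is a hypothetical helper (my own, not part of AutoAWQ) that inverts the packing above: each int32 is shifted and masked to recover its 8 int4 values, undoing the order_map interleaving.

import torch

def unpack_qweight(qweight: torch.Tensor, w_bit: int = 4) -> torch.Tensor:
    pack_num = 32 // w_bit                    # 8 int4 values per int32
    order_map = [0, 2, 4, 6, 1, 3, 5, 7]      # same interleaving as the packing loop
    out = torch.zeros(
        qweight.shape[0], qweight.shape[1] * pack_num,
        dtype=torch.int32, device=qweight.device,
    )
    for col in range(qweight.shape[1]):
        for i in range(pack_num):
            val = (qweight[:, col] >> (i * w_bit)) & ((1 << w_bit) - 1)
            out[:, col * pack_num + order_map[i]] = val
    return out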

In both the forward and the backward pass, GEMM has to dequantize the weights before doing the actual computation.

class WQLinearMMFunction(Function):
    @staticmethod
    # ctx is the first argument to forward
    def forward(
        ctx,
        x,
        qweight,
        qzeros,
        scales,
        w_bit=4,
        group_size=128,
        bias=None,
        out_features=0,
    ):
        # The forward pass can use ctx.
        ctx.save_for_backward(x, qweight, qzeros, scales, bias)
        ctx.out_features = out_features

        out_shape = x.shape[:-1] + (out_features,)
        x = x.to(torch.float16)

        if AWQ_INSTALLED:
            FP16_MATMUL_HEURISTIC_CONDITION = x.shape[0] * x.shape[1] >= 1024

            if FP16_MATMUL_HEURISTIC_CONDITION:
                out = awq_ext.dequantize_weights_cuda(
                    qweight, scales, qzeros, 0, 0, 0, False
                )
                out = torch.matmul(x, out)
            else:
                out = awq_ext.gemm_forward_cuda(
                    x.reshape(-1, x.shape[-1]), qweight, scales, qzeros, 8
                )
        else:
            out = dequantize_gemm(qweight, qzeros, scales, w_bit, group_size)
            out = torch.matmul(x, out)

        out = out + bias if bias is not None else out
        out = out.reshape(out_shape)

        # always want 3D tensor if tensor is 2D
        if len(out.shape) == 2:
            out = out.unsqueeze(0)

        return out
