NF4量化算法的PyTorch实现

为了方便理解NF4算法的实现,这里用PyTorch实现了一版可以和CUDA NF4精度对齐的量化和反量化函数,并使用llama-3.1-8b模型进行测试,可以做到和CUDA实现的算子精度基本对齐(仅反量化存在少许误差),并对模型输出进行测试,64个tokens和CUDA实现完全一致。

以下都只是在RTX3090上对llama-3.1-8b上进行测试的结果,不能代表全部的设备和模型。

CUDA上使用dQuantizeNF4函数使用float类型的xfloat类型的NF4表的中间值进行比较,从而得到表中距离x的最近元素的索引。

__device__ unsigned char dQuantizeNF4(float x)
{
   
  // the values for this tree was generated by test_normal_map_tree
  // in the file tests/test_functional.py
  if(x > 0.03979014977812767f)
    if(x > 0.3893125355243683f) // 1
      if(x > 0.6427869200706482f) // 11
        if(x > 0.8614784181118011f) // 111
          return 0b1111;
        else
          return 0b1110;
      else
        if(x > 0.5016634166240692f) // 110
          return 0b1101;
        else
          return 0b1100;
    else
      if(x > 0.2035212516784668f) // 10
        if(x > 0.2920137718319893f) // 101
          return 0b1011;
        else
          return 0b1010;
      else
        if(x > 0.1202552504837513f) // 100
          return 0b1001;
        else
          return 0b1000;
  else
    if(x > -0.33967943489551544f) // 0
      if(x > -0.13791173323988914f) // 01
        if(x > -0.045525018125772476f) // 011
          return 0b0111;
        else
          return 0b0110;
      else
        if(x > -0.23460740596055984f) // 010
          return 0b0101;
        else
          return 0b0100;
    else
      if(x > -0.6106329262256622f) // 00
        if(x > -0.4599952697753906f) // 001
          return 0b0011;
        else
          return 0b0010;
      else
        if(x > -0.8480964004993439f) // 000
          return 0b0001;
        else
          return 0b0000;
}

因此在实现时也需要注意MAPPINGabsmax的类型都需要是float32,经过在实际的llama3权重数据上测试:

  • 量化函数PyTorch实现可以和CUDA实现精度对齐;
  • 反量化函数平均绝对误差大约在1.5e-6,不影响模型输出。
BNB_MAP = [-1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453, -0.28444138169288635, -0.18477343022823334, -
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

风好衣轻

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值