To make the NF4 algorithm easier to understand, this post implements PyTorch versions of the quantization and dequantization functions whose precision aligns with the CUDA NF4 kernels. Tested on the llama-3.1-8b model, they essentially match the CUDA operators (only dequantization shows a slight error), and the model output was verified as well: the first 64 generated tokens are identical to the CUDA implementation.
Note that all results below were measured on a single RTX 3090 with llama-3.1-8b; they may not generalize to other devices or models.
On the CUDA side, the dQuantizeNF4 function compares a float-typed x against float-typed midpoints of the NF4 table, descending a binary tree to find the index of the table entry closest to x:
__device__ unsigned char dQuantizeNF4(float x)
{
  // the values for this tree was generated by test_normal_map_tree
  // in the file tests/test_functional.py
  if(x > 0.03979014977812767f)
    if(x > 0.3893125355243683f) // 1
      if(x > 0.6427869200706482f) // 11
        if(x > 0.8614784181118011f) // 111
          return 0b1111;
        else
          return 0b1110;
      else
        if(x > 0.5016634166240692f) // 110
          return 0b1101;
        else
          return 0b1100;
    else
      if(x > 0.2035212516784668f) // 10
        if(x > 0.2920137718319893f) // 101
          return 0b1011;
        else
          return 0b1010;
      else
        if(x > 0.1202552504837513f) // 100
          return 0b1001;
        else
          return 0b1000;
  else
    if(x > -0.33967943489551544f) // 0
      if(x > -0.13791173323988914f) // 01
        if(x > -0.045525018125772476f) // 011
          return 0b0111;
        else
          return 0b0110;
      else
        if(x > -0.23460740596055984f) // 010
          return 0b0101;
        else
          return 0b0100;
    else
      if(x > -0.6106329262256622f) // 00
        if(x > -0.4599952697753906f) // 001
          return 0b0011;
        else
          return 0b0010;
      else
        if(x > -0.8480964004993439f) // 000
          return 0b0001;
        else
          return 0b0000;
}
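The comparison constants in this tree are not arbitrary: each one is the float32 midpoint between two adjacent entries of the NF4 code book. A minimal sketch (mine, not from bitsandbytes) that reproduces them from the table:

import torch

# The 16-entry NF4 code book used by bitsandbytes (same values as BNB_MAP below).
nf4 = torch.tensor([-1.0, -0.6961928009986877, -0.5250730514526367,
                    -0.39491748809814453, -0.28444138169288635, -0.18477343022823334,
                    -0.09105003625154495, 0.0, 0.07958029955625534, 0.16093020141124725,
                    0.24611230194568634, 0.33791524171829224, 0.44070982933044434,
                    0.5626170039176941, 0.7229568362236023, 1.0], dtype=torch.float32)

# Each threshold in dQuantizeNF4 is the float32 midpoint of two neighboring entries.
midpoints = (nf4[:-1] + nf4[1:]) / 2
print(midpoints)
# midpoints[7]  matches 0.03979014977812767, the root of the tree;
# midpoints[14] matches 0.8614784181118011, the "111" threshold;
# midpoints[0]  matches -0.8480964004993439, the "000" threshold; and so on.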
Accordingly, the PyTorch implementation must keep both MAPPING and absmax in float32; the lookup is sensitive to the table's precision, as the sketch after this list illustrates. Testing on real llama3 weight data shows that:
- the PyTorch quantization function matches the CUDA implementation exactly;
- the dequantization function has a mean absolute error of roughly 1.5e-6, which does not affect model output.
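As a quick illustration of that sensitivity (my own check, not from the original tests), casting the table to float16 shifts the effective decision boundaries enough to flip some 4-bit indices:

import torch

BNB_MAP = [-1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453, -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0, 0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224, 0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0]
m32 = torch.tensor(BNB_MAP, dtype=torch.float32).view(1, -1)
m16 = m32.to(torch.float16)  # deliberately degraded table

# Normalized values in [-1, 1], as produced by dividing a block by its absmax.
a = torch.rand(1_000_000, 1) * 2 - 1
idx32 = (a - m32).abs().argmin(dim=-1)
idx16 = (a.to(torch.float16) - m16).abs().argmin(dim=-1)
print((idx32 != idx16).sum().item(), "of", a.numel(), "indices differ")

With the table kept in float32 this mismatch disappears, which is what the reference implementation below relies on: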
import torch

# The NF4 code book; it must stay in float32 to match the CUDA kernel.
BNB_MAP = [-1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453, -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0, 0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224, 0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0]
MAPPING = torch.tensor(BNB_MAP, device="cuda", dtype=torch.float32).view(1, -1)

def py_quantize_nf4(A, blocksize=64):
    shape = A.shape
    # Per-block scale: the absolute maximum of each block.
    absmax = A.view(-1, blocksize).abs().max(dim=1, keepdim=True).values
    # Normalize each block into [-1, 1]; absmax is cast to float32 for the division.
    a = A.view(-1, blocksize) / absmax.float()
    # Nearest-neighbor search against the code book.
    diff = torch.abs(a.unsqueeze(-1) - MAPPING)
    out = torch.argmin(diff, dim=-1)
    # Pack two 4-bit indices per byte, first index in the high nibble.
    out = out.reshape(-1, 2)
    out = (out[:, 0] * 16 + out[:, 1]).to(torch.uint8)
    return out, absmax, shape
def py_dequantize_nf4(A, absmax, shape, dtype, blocksize=64):
    A = A.view(-1)
    # Unpack the high and low nibbles back into 4-bit indices.
    A = torch.stack([A // 16, A % 16], dim=1).to(torch.int32)
    # Look up the code book, then rescale each block by its absmax.
    out = MAPPING.reshape(-1)[A]
    out = out.view(-1, blocksize) * absmax.reshape(-1, 1)
    out = out.reshape(*shape).to(dtype)
    return out
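A quick round trip through both functions (shapes chosen arbitrarily for illustration):

w = torch.randn(128, 64, dtype=torch.float16, device="cuda")
packed, absmax, shape = py_quantize_nf4(w, blocksize=64)
w_hat = py_dequantize_nf4(packed, absmax, shape, w.dtype, blocksize=64)
print(packed.shape, (w - w_hat).abs().mean().item())  # 4096 packed bytes for 8192 weights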
Swapping these two functions in for the CUDA kernels inside bitsandbytes produces model output whose first 64 tokens are exactly identical:
<|begin_of_text|>Once upon a time, 20 years ago, I was a young, idealistic, and naive college student. I was also a young, idealistic, and naive college student who was a member of the Young Republicans Club. I was also a young, idealistic, and naive college student who was a member of the Young Republicans Club who was
The PyTorch implementation does pay a performance cost: quantizing the 8B model takes about 10 s versus about 3 s for the CUDA implementation, and generating 64 tokens takes 28.012 s with the PyTorch version (the slowdown comes almost entirely from the dequantization function) versus 3.65512 s with CUDA.
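For reference, one way to perform that replacement (a sketch only, not necessarily how it was done here; it assumes the quant_state object exposes absmax, shape, and dtype, as the comparison script below also does) is to monkey-patch bitsandbytes.functional.dequantize_4bit:

from bitsandbytes import functional as F

_cuda_dequantize_4bit = F.dequantize_4bit  # keep the CUDA path for fallback

def patched_dequantize_4bit(A, quant_state=None, absmax=None, out=None,
                            blocksize=64, quant_type="fp4"):
    if quant_type == "nf4" and quant_state is not None:
        # Route the NF4 path through the PyTorch reference implementation.
        bs = getattr(quant_state, "blocksize", blocksize)
        return py_dequantize_nf4(A, quant_state.absmax, quant_state.shape,
                                 quant_state.dtype, bs)
    return _cuda_dequantize_4bit(A, quant_state, absmax, out=out,
                                 blocksize=blocksize, quant_type=quant_type)

F.dequantize_4bit = patched_dequantize_4bit

Depending on the bitsandbytes version, internal call sites may bind the function before the patch runs, so it has to be applied early.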
The precision comparison script:
import torch
from bitsandbytes import functional as F

BNB_MAP = [-1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453, -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0, 0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224, 0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0]
MAPPING = torch.tensor(BNB_MAP, device="cuda", dtype=torch.float32).view(1, -1)

def py_quantize_nf4(A, blocksize=64):
    shape = A.shape
    absmax = A.view(-1, blocksize).abs().max(dim=1, keepdim=True).values
    a = A.view(-1, blocksize) / absmax.float()
    diff = torch.abs(a.unsqueeze(-1) - MAPPING)
    out = torch.argmin(diff, dim=-1)
    out = out.reshape(-1, 2)
    out = (out[:, 0] * 16 + out[:, 1]).to(torch.uint8)
    return out, absmax, shape

def py_dequantize_nf4(A, absmax, shape, dtype, blocksize=64):
    A = A.view(-1)
    A = torch.stack([A // 16, A % 16], dim=1).to(torch.int32)
    out = MAPPING.reshape(-1)[A]
    out = out.view(-1, blocksize) * absmax.reshape(-1, 1)
    out = out.reshape(*shape).to(dtype)
    return out
def quantize(A, blocksize):
    out, state = F.quantize_4bit(
        A,
        absmax=None,
        out=None,
        blocksize=blocksize,
        compress_statistics=False,
        quant_type="nf4",
        quant_storage=torch.uint8,
    )
    out1, absmax1, shape1 = py_quantize_nf4(A, blocksize)
    # Compare the packed bytes and absmax against the CUDA kernel.
    quant_error = (out1.view(-1).to(torch.int32) - out.view(-1).to(torch.int32)).abs().max().item()
    absmax_error0 = (torch.abs(state.absmax.view(-1) - absmax1.view(-1))).max().item()
    absmax_error1 = (torch.abs(state.absmax.view(-1) - absmax1.view(-1))).sum().item()
    print(f"[+] quant error (max abs): {quant_error}, absmax error (max abs): {absmax_error0}, absmax error (sum): {absmax_error1}")
    return out, state
def dequantize(A, absmax, blocksize, quant_state):
    out = F.dequantize_4bit(
        A,
        quant_state,
        absmax,
        out=None,
        blocksize=blocksize,
        quant_type="nf4",
    )
    out1 = py_dequantize_nf4(A, absmax, quant_state.shape, quant_state.dtype, blocksize)
    # Uncomment to print the max/sum absolute difference against the CUDA kernel:
    # print((torch.abs(out - out1)).max().item(), (torch.abs(out - out1)).sum().item())
    return out
def py_quantize_nf4_chunk(A, blocksize=64, chunk=8):
    # Chunked variant: processes `chunk` blocks at a time to cap the peak
    # memory of the (blocks, blocksize, 16) diff tensor. Not used in the
    # main loop below.
    shape = A.shape
    total_blocks = A.numel() // blocksize
    chunks = (total_blocks + chunk - 1) // chunk
    absmax_list = []
    out_list = []
    for i in range(chunks):
        start = i * chunk * blocksize
        end = min((i + 1) * chunk * blocksize, A.numel())
        chunk_data = A.view(-1)[start:end].view(-1, blocksize)
        absmax = chunk_data.abs().max(dim=1, keepdim=True).values
        absmax_list.append(absmax)
        a = chunk_data / absmax.float()
        diff = torch.abs(a.unsqueeze(-1) - MAPPING)
        out = torch.argmin(diff, dim=-1)
        out = out.reshape(-1, 2)
        out = (out[:, 0] * 16 + out[:, 1]).to(torch.uint8)
        out_list.append(out)
    absmax = torch.cat(absmax_list, dim=0)
    out = torch.cat(out_list, dim=0)
    return out, absmax, shape
if __name__ == "__main__":
    for blocksize in [64, 128, 256, 512, 1024]:
        for i in range(10):
            # Shrink the value range as i grows to test different scales.
            x = torch.randn(4096, 4096, dtype=torch.float16, device='cuda') / max(i * 10, 1)
            xq, state = quantize(x, blocksize)
            xd = dequantize(xq, state.absmax, blocksize, state)
            err = torch.abs(x - xd).mean().item()
            print(f"Error: {err}, x.abs().mean: {x.abs().mean()}")