高性能 :DeepSeek-V3 inference 推理时反量化实现 fp8_cast_bf16

FP8 (8 bits) & FP16 (16 bits)

  • FP8 和 BF16 都是浮点数格式(floating-point formats),float通过科学计数法表示数据,float = [符号位+指数位+系数位]
FP8 (8 bits):SEEEMMMMFP16 (16 bits):SEEEEEMMMMMMMMMM
S (1 bit)S (1 bit)
EEE (3 bits)EEEEE (5 bits)
MMMM (4 bits)MMMMMMMMMM (10 bits)
  • FP8:1位符号位、3位指数位、4位尾数位。
  • FP16:1位符号位、5位指数位、10位尾数位。
特性FP8BF16
位数8 位16 位
存储需求非常低低(但高于 FP8)
精度精度非常低,仅适合低精度计算较低的精度,但比 FP8 精度高
范围较小的数值范围与 FP32 相似,具有广泛的数值范围
主要用途主要用于训练中的权重表示主要用于训练和推理,尤其适用于加速机器学习
优点极大的存储节省和计算加速适用于大规模深度学习模型,精度损失较小

fp8_cast_bf16

  • FP8到BF16转换: 主要通过weight_dequant函数将FP8权重转换为BF16格式。
import os  # 导入操作系统接口模块,用于文件和目录操作
import json  # 导入JSON模块,用于读取和写入JSON格式的数据
from argparse import ArgumentParser  # 导入ArgumentParser类,用于命令行参数解析
from glob import glob  # 导入glob模块,用于文件路径模式匹配
from tqdm import tqdm  # 导入tqdm模块,用于显示进度条

import torch  # 导入PyTorch库
from safetensors.torch import load_file, save_file  # 从safetensors库导入load_file和save_file函数

from kernel import weight_dequant  # 从kernel模块导入weight_dequant函数,用于权重解量化

def main(fp8_path, bf16_path):
    """
    将FP8权重转换为BF16并保存转换后的权重。

    该函数从指定的目录读取FP8权重,将其转换为BF16格式,
    并将转换后的权重保存到另一个指定的目录。它还更新了
    模型索引文件,反映出这些更改。

    参数:
    fp8_path (str): 存放FP8权重和模型索引文件的目录路径。
    bf16_path (str): 保存转换后的BF16权重的目录路径。

    异常:
    KeyError: 如果缺少所需的scale_inv张量,则会引发此异常。

    注意:
    - 假定FP8权重存储为safetensor文件。
    - 该函数缓存已加载的safetensor文件以优化内存使用。
    - 函数更新模型索引文件,删除对scale_inv张量的引用。
    """
    # 设置默认数据类型为bfloat16
    torch.set_default_dtype(torch.bfloat16)
    os.makedirs(bf16_path, exist_ok=True)  # 如果输出目录不存在,则创建它
    model_index_file = os.path.join(fp8_path, "model.safetensors.index.json")  # 模型索引文件路径
    with open(model_index_file, "r") as f:
        model_index = json.load(f)  # 读取模型索引文件
    weight_map = model_index["weight_map"]  # 获取权重映射

    # 用于缓存已加载的safetensor文件
    loaded_files = {}
    fp8_weight_names = []  # 用于存储FP8权重的名称

    def get_tensor(tensor_name):
        """
        从缓存的safetensor文件中检索张量,如果没有缓存则从磁盘加载。

        参数:
            tensor_name (str): 要检索的张量名称。

        返回:
            torch.Tensor: 检索到的张量。

        异常:
            KeyError: 如果在safetensor文件中找不到指定的张量,则引发此异常。
        """
        file_name = weight_map[tensor_name]  # 获取该张量所在的文件名
        if file_name not in loaded_files:  # 如果该文件未加载
            file_path = os.path.join(fp8_path, file_name)  # 构建文件路径
            loaded_files[file_name] = load_file(file_path, device="cuda")  # 加载文件并缓存
        return loaded_files[file_name][tensor_name]  # 返回缓存的张量

    # 获取所有safetensor文件路径,并按字母排序
    safetensor_files = list(glob(os.path.join(fp8_path, "*.safetensors")))
    safetensor_files.sort()

    # 遍历所有的safetensor文件
    for safetensor_file in tqdm(safetensor_files):
        file_name = os.path.basename(safetensor_file)  # 获取文件名
        current_state_dict = load_file(safetensor_file, device="cuda")  # 加载当前safetensor文件
        loaded_files[file_name] = current_state_dict  # 将文件缓存起来
        
        new_state_dict = {}  # 用于存储转换后的新权重字典
        for weight_name, weight in current_state_dict.items():  # 遍历文件中的所有权重
            if weight_name.endswith("_scale_inv"):  # 如果权重是scale_inv,跳过
                continue
            elif weight.element_size() == 1:  # 如果权重是FP8(即1字节)
                scale_inv_name = f"{weight_name}_scale_inv"  # 对应的scale_inv张量名称
                try:
                    # 尝试获取对应的scale_inv张量
                    scale_inv = get_tensor(scale_inv_name)
                    fp8_weight_names.append(weight_name)  # 将FP8权重名称记录下来
                    new_state_dict[weight_name] = weight_dequant(weight, scale_inv)  # 转换为BF16
                except KeyError:
                    # 如果没有找到scale_inv张量,则跳过转换
                    print(f"Warning: Missing scale_inv tensor for {weight_name}, skipping conversion")
                    new_state_dict[weight_name] = weight  # 保留原始权重
            else:
                new_state_dict[weight_name] = weight  # 如果不是FP8,直接保留原始权重
        
        # 保存转换后的权重
        new_safetensor_file = os.path.join(bf16_path, file_name)
        save_file(new_state_dict, new_safetensor_file)
        
        # 内存管理:保持仅2个最近使用的文件
        if len(loaded_files) > 2:
            oldest_file = next(iter(loaded_files))  # 获取最老的文件
            del loaded_files[oldest_file]  # 删除最老的文件
            torch.cuda.empty_cache()  # 清理缓存

    # 更新模型索引文件
    new_model_index_file = os.path.join(bf16_path, "model.safetensors.index.json")
    for weight_name in fp8_weight_names:  # 遍历所有FP8权重
        scale_inv_name = f"{weight_name}_scale_inv"  # 对应的scale_inv名称
        if scale_inv_name in weight_map:
            weight_map.pop(scale_inv_name)  # 从weight_map中删除scale_inv权重
    with open(new_model_index_file, "w") as f:
        json.dump({"metadata": {}, "weight_map": weight_map}, f, indent=2)  # 保存更新后的索引文件

if __name__ == "__main__":
    # 设置命令行参数解析
    parser = ArgumentParser()
    parser.add_argument("--input-fp8-hf-path", type=str, required=True)  # 输入FP8权重路径
    parser.add_argument("--output-bf16-hf-path", type=str, required=True)  # 输出BF16权重路径
    args = parser.parse_args()
    main(args.input_fp8_hf_path, args.output_bf16_hf_path)  # 调用主函数进行转换

weight_dequant

from typing import Tuple
import torch
import triton
import triton.language as tl # Triton 语言(Triton Language)允许用户在 GPU 上编写高效的并行计算内核https://github.com/triton-lang/triton
from triton import Config
  • weight_dequant 函数用于将量化的权重张量(x)进行反量化处理,恢复到浮动值。以下是注释的解释:
def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> torch.Tensor:
    """
    Dequantizes the given weight tensor using the provided scale tensor.

    Args:
        x (torch.Tensor): The quantized weight tensor of shape (M, N).
        s (torch.Tensor): The scale tensor of shape (M, N).
        block_size (int, optional): The block size to use for dequantization. Defaults to 128.

    Returns:
        torch.Tensor: The dequantized weight tensor of the same shape as `x`.

    Raises:
        AssertionError: If `x` or `s` are not contiguous or if their dimensions are not 2.
    """
    assert x.is_contiguous() and s.is_contiguous(), 'Input tensors must be contiguous' # 确保输入张量是连续的(即内存布局连续)
    
    
    assert x.dim() == 2 and s.dim() == 2, 'Input tensors must have 2 dimensions' # 确保输入张量 x 和 s 都是二维的
    
    M, N = x.size() # 获取输入张量 x 的尺寸 M (行数) 和 N (列数)

    # 创建一个和 x 形状相同的新张量 y,用来保存反量化后的结果
    y = torch.empty_like(x, dtype=torch.get_default_dtype())

    # 定义一个 grid 函数来计算 triton 内核所需的网格大小
    # triton.cdiv 是向上取整除法,用来确保我们分配足够的线程处理每个块
    grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE']), triton.cdiv(N, meta['BLOCK_SIZE']))

    # 调用 triton 内核 `weight_dequant_kernel` 进行反量化操作
    # 将 quantized weight `x` 和 scale `s` 与结果张量 `y` 一起传递给内核
    # `M`, `N`, `block_size` 作为额外的参数传递
    weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size)

    # 返回反量化后的张量 y
    return y
  • 计算网格大小:
    • grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE']), triton.cdiv(N, meta['BLOCK_SIZE'])): 使用 triton.cdiv 来计算块的数量。triton.cdiv 是向上取整除法,用于确定每个维度需要多少个块来处理 MN 大小的数据。meta['BLOCK_SIZE']) 是每个块处理的元素数量(默认值为 128)。

weight_dequant_kernel

@triton.jit  # 使用 Triton 编译器将此函数编译为高效的 GPU 内核
def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
    """
    Dequantizes weights using the provided scaling factors and stores the result.

    Args:
        x_ptr (tl.pointer): Pointer to the quantized weights.
        s_ptr (tl.pointer): Pointer to the scaling factors.
        y_ptr (tl.pointer): Pointer to the output buffer for dequantized weights.
        M (int): Number of rows in the weight matrix.
        N (int): Number of columns in the weight matrix.
        BLOCK_SIZE (tl.constexpr): Size of the block for tiling.

    Returns:
        None
    """
    
    # 获取当前线程在程序中的编号
    pid_m = tl.program_id(axis=0)  # 获取当前行维度上的线程编号,pid_m 和 pid_n 的范围由矩阵的尺寸 M 和 N,以及线程块的大小 BLOCK_SIZE 决定
    pid_n = tl.program_id(axis=1)  # 获取当前列维度上的线程编号,pid_m 的值从 0 到 ceil(M / BLOCK_SIZE) - 1
    
    # 计算矩阵列的块数
    n = tl.cdiv(N, BLOCK_SIZE)  # 使用向上取整除法计算列方向上的块数
    
    # 计算当前线程块在行和列方向上的偏移量
    offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)  # 当前块在行方向的偏移量
    offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)  # 当前块在列方向的偏移量
    # 将行和列的偏移量组合成一个二维的索引数组
    offs = offs_m[:, None] * N + offs_n[None, :]  # 将行和列偏移量结合,得到每个元素的全局索引,offs_m[:, None]形状会变成 (BLOCK_SIZE, 1),相加广播后变为(BLOCK_SIZE, BLOCK_SIZE)
    
    # 使用掩码保证我们不会超出矩阵的边界
    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)  # 掩码,确保线程不会访问超出矩阵范围的数据
    
    # 加载量化后的权重数据(量化后的值)
    x = tl.load(x_ptr + offs, mask=mask).to(tl.float32)  # 从内存中加载量化后的数据,并转换为 float32 类型
    
    # 加载缩放因子
    s = tl.load(s_ptr + pid_m * n + pid_n)  # 从内存中加载对应的缩放因子,s_ptr是指向缩放因子数组的指针,pid_m * n + pid_n计算出当前线程块在缩放因子数组中的位置。
    
    # 执行去量化操作:去量化 = 量化值 * 缩放因子
    y = x * s  # 去量化的计算公式
    
    # 将去量化后的数据存储到输出缓存中
    tl.store(y_ptr + offs, y, mask=mask)  # 将去量化后的值存储到输出内存中,使用掩码确保数据存储在合法的范围内,`offs` 是索引,`mask=mask` 确保只有合法的元素被存储
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值