投机采样 Speculative Decoding -- EAGLE


来源: 详细解释内容可参考 EAGLE投机采样

投机采样



import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# 初始化模型和tokenizer
draft_model_name = "google/gemma-2b" # 选用一个轻量级模型
target_model_name = "meta-llama/Llama-2-7b-chat-hf" # 选用一个性能更好的模型

draft_model = AutoModelForCausalLM.from_pretrained(draft_model_name, device_map="auto")
target_model = AutoModelForCausalLM.from_pretrained(target_model_name, device_map="auto")

draft_tokenizer = AutoTokenizer.from_pretrained(draft_model_name)
target_tokenizer = AutoTokenizer.from_pretrained(target_model_name)

def speculative_decode(prompt, draft_model, target_model, draft_tokenizer, target_tokenizer, num_draft_tokens=5, acceptance_threshold=0.8):
    """
    简单的投机采样实现

    Args:
        prompt: 输入prompt (字符串)
        draft_model: Draft Model
        target_model: Target Model
        draft_tokenizer: Draft Model的tokenizer
        target_tokenizer: Target Model的tokenizer
        num_draft_tokens: Draft Model生成的token数量
        acceptance_threshold: 接受Draft token的概率阈值

    Returns:
        生成的文本 (字符串)
    """

    # 1. Draft 阶段
    draft_input = draft_tokenizer(prompt, return_tensors="pt").to(draft_model.device)
    draft_output = draft_model.generate(**draft_input, max_new_tokens=num_draft_tokens)
    draft_tokens = draft_output[:, draft_input['input_ids'].shape[-1]:]
    draft_text = draft_tokenizer.batch_decode(draft_tokens, skip_special_tokens=True)[0]

    # 2. Verify 阶段
    target_input = target_tokenizer(prompt, return_tensors="pt").to(target_model.device)
    target_logits = target_model(**target_input).logits
    initial_target_token = torch.argmax(target_logits[:, -1, :], dim=-1)  # 获取第一个token的预测

    accepted_tokens = [initial_target_token.item()]  # 存储接受的token, 初始化第一个token
    rejected_indices = []

    # 迭代验证Draft Model生成的token
    for i in range(num_draft_tokens):
        # 构建上下文,包含prompt和之前接受的token
        context_tokens = target_tokenizer(prompt, return_tensors="pt").to(target_model.device)['input_ids']
        context_tokens = torch.cat([context_tokens, torch.tensor([[accepted_tokens[j] for j in range(len(accepted_tokens))]])], dim=1)

        target_logits = target_model(**{"input_ids": context_tokens}).logits  # 注意使用context_tokens作为输入
        target_probs = torch.softmax(target_logits[:, -1, :], dim=-1)
        target_prob_for_draft_token = target_probs[0, draft_tokens[0, i]].item()

        if target_prob_for_draft_token >= acceptance_threshold:
            accepted_tokens.append(draft_tokens[0, i].item())
        else:
            rejected_indices.append(i)
            break # 一旦有token被拒绝,停止验证

    # 3. 生成剩余部分 (如果还有未验证的token)
    if rejected_indices:
        # 在被拒绝的token位置,使用Target Model生成新的token
        context_tokens = target_tokenizer(prompt, return_tensors="pt").to(target_model.device)['input_ids']
        context_tokens = torch.cat([context_tokens, torch.tensor([[accepted_tokens[j] for j in range(len(accepted_tokens))]])], dim=1)

        target_output = target_model.generate(**{"input_ids": context_tokens}, max_new_tokens=1)  # 生成一个新token
        new_token = target_output[:, context_tokens.shape[-1]:]
        accepted_tokens.append(new_token[0, 0].item()) # 将新生成的token添加到accepted_tokens

    # 将所有接受的token转换为文本
    generated_text = target_tokenizer.batch_decode([accepted_tokens], skip_special_tokens=True)[0]

    return generated_text

# 示例用法
prompt = "The capital of France is"
generated_text = speculative_decode(prompt, draft_model, target_model, draft_tokenizer, target_tokenizer)
print(f"Generated text: {generated_text}")

模块分析

导入库
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
  • torch:PyTorch深度学习框架,用于张量计算和模型训练。
  • AutoModelForCausalLM:Hugging Face Transformers提供的自动加载因果语言模型的类。
  • AutoTokenizer:Hugging Face Transformers提供的自动加载tokenizer的类。

模型初始化

draft_model_name = "google/gemma-2b"
target_model_name = "meta-llama/Llama-2-7b-chat-hf"

draft_model = AutoModelForCausalLM.from_pretrained(draft_model_name, device_map="auto")
target_model = AutoModelForCausalLM.from_pretrained(target_model_name, device_map="auto")

draft_tokenizer = AutoTokenizer.from_pretrained(draft_model_name)
target_tokenizer = AutoTokenizer.from_pretrained(target_model_name)
  • draft_model_nametarget_model_name:分别指定轻量级草稿模型和更强大的目标模型的Hugging Face模型名称。
  • AutoModelForCausalLM.from_pretrained:加载预训练的语言模型,device_map="auto"自动分配设备(GPU/CPU)。
  • AutoTokenizer.from_pretrained:加载与模型对应的tokenizer。

投机采样函数

def speculative_decode(prompt, draft_model, target_model, draft_tokenizer, target_tokenizer, num_draft_tokens=5, acceptance_threshold=0.8):
  • 参数说明:
    • prompt:输入的文本提示。
    • draft_modeltarget_model:草稿模型和目标模型。
    • draft_tokenizertarget_tokenizer:对应的tokenizer。
    • num_draft_tokens:草稿模型生成的token数量。
    • acceptance_threshold:接受草稿token的概率阈值。

Draft阶段
draft_input = draft_tokenizer(prompt, return_tensors="pt").to(draft_model.device)
draft_output = draft_model.generate(**draft_input, max_new_tokens=num_draft_tokens)
draft_tokens = draft_output[:, draft_input['input_ids'].shape[-1]:]
draft_text = draft_tokenizer.batch_decode(draft_tokens, skip_special_tokens=True)[0]
  • draft_tokenizer:将输入文本转换为模型输入的张量。
  • draft_model.generate:生成指定数量的token。
  • draft_tokens:提取生成的token(排除输入部分)。
  • draft_text:将生成的token解码为文本。

Verify阶段
target_input = target_tokenizer(prompt, return_tensors="pt").to(target_model.device)
target_logits = target_model(**target_input).logits
initial_target_token = torch.argmax(target_logits[:, -1, :], dim=-1)
accepted_tokens = [initial_target_token.item()]
rejected_indices = []
  • target_tokenizer:将输入文本转换为目标模型的输入张量。
  • target_model:获取目标模型的输出logits。
  • initial_target_token:获取目标模型预测的第一个token。
  • accepted_tokens:初始化接受的token列表。
  • rejected_indices:存储被拒绝的token索引。

验证草稿token
for i in range(num_draft_tokens):
    context_tokens = target_tokenizer(prompt, return_tensors="pt").to(target_model.device)['input_ids']
    context_tokens = torch.cat([context_tokens, torch.tensor([[accepted_tokens[j] for j in range(len(accepted_tokens))]])], dim=1)

    target_logits = target_model(**{"input_ids": context_tokens}).logits
    target_probs = torch.softmax(target_logits[:, -1, :], dim=-1)
    target_prob_for_draft_token = target_probs[0, draft_tokens[0, i]].item()

    if target_prob_for_draft_token >= acceptance_threshold:
        accepted_tokens.append(draft_tokens[0, i].item())
    else:
        rejected_indices.append(i)
        break
  • context_tokens:构建包含输入和已接受token的上下文。
  • target_model:计算目标模型对当前上下文的预测概率。
  • target_prob_for_draft_token:获取草稿token在目标模型中的概率。
  • 如果概率高于阈值,接受该token;否则拒绝并终止验证。

生成剩余部分
if rejected_indices:
    context_tokens = target_tokenizer(prompt, return_tensors="pt").to(target_model.device)['input_ids']
    context_tokens = torch.cat([context_tokens, torch.tensor([[accepted_tokens[j] for j in range(len(accepted_tokens))]])], dim=1)

    target_output = target_model.generate(**{"input_ids": context_tokens}, max_new_tokens=1)
    new_token = target_output[:, context_tokens.shape[-1]:]
    accepted_tokens.append(new_token[0, 0].item())
  • 如果存在被拒绝的token,使用目标模型生成一个新的token。
  • target_model.generate:生成一个新token。
  • new_token:提取生成的token并添加到接受列表。

输出结果
generated_text = target_tokenizer.batch_decode([accepted_tokens], skip_special_tokens=True)[0]
return generated_text
  • batch_decode:将接受的token解码为文本。
  • 返回生成的文本。

示例用法

prompt = "The capital of France is"
generated_text = speculative_decode(prompt, draft_model, target_model, draft_tokenizer, target_tokenizer)
print(f"Generated text: {generated_text}")
  • 输入提示文本,调用投机采样函数生成文本并打印结果。

EAGLE


import torch
import torch.nn as nn
import torch.optim as optim
from transformers import AutoModelForCausalLM, AutoTokenizer
import numpy as np

# 假设我们使用困惑度作为特征,并简化特征预测模型

class FeaturePredictor(nn.Module):
    def __init__(self, hidden_size):
        super(FeaturePredictor, self).__init__()
        self.linear = nn.Linear(hidden_size, 1) # 预测困惑度

    def forward(self, hidden_states):
        # hidden_states: (batch_size, sequence_length, hidden_size)
        perplexity = self.linear(hidden_states).squeeze(-1) # (batch_size, sequence_length)
        return perplexity

def train_feature_predictor(target_model, tokenizer, feature_predictor, num_epochs=3, learning_rate=1e-4):
    """
    训练特征预测器,使用Target Model的hidden states和困惑度作为训练数据

    Args:
        target_model: Target Model
        tokenizer: Target Model的tokenizer
        feature_predictor: 特征预测模型
        num_epochs: 训练轮数
        learning_rate: 学习率
    """
    optimizer = optim.Adam(feature_predictor.parameters(), lr=learning_rate)
    criterion = nn.MSELoss()

    # 生成一些训练数据
    texts = ["The quick brown fox jumps over the lazy dog.",
             "The capital of France is Paris.",
             "Machine learning is a fascinating field.",
             "Coding is fun and challenging."]

    for epoch in range(num_epochs):
        for text in texts:
            inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True).to(target_model.device)
            with torch.no_grad():
                outputs = target_model(**inputs, output_hidden_states=True)
                hidden_states = outputs.hidden_states[-1] # 使用最后一层的hidden states

                # 计算困惑度 (作为ground truth)
                logits = outputs.logits
                shift_logits = logits[:, :-1, :].contiguous()
                shift_labels = inputs['input_ids'][:, 1:].contiguous()
                loss_fct = nn.CrossEntropyLoss()
                loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
                perplexity = torch.exp(loss)

            # 使用hidden states预测困惑度
            predicted_perplexity = feature_predictor(hidden_states[:, :-1, :]) # 预测除了第一个token之外的perplexity

            # 计算损失并更新模型
            loss = criterion(predicted_perplexity, torch.full_like(predicted_perplexity, perplexity.item()))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print(f"Epoch {epoch+1}, Loss: {loss.item()}")

def eagle_speculative_decode(prompt, draft_model, target_model, draft_tokenizer, target_tokenizer, feature_predictor, num_draft_tokens=5, acceptance_threshold=0.8, alpha=0.5):
    """
    EAGLE投机采样实现

    Args:
        prompt: 输入prompt (字符串)
        draft_model: Draft Model
        target_model: Target Model
        draft_tokenizer: Draft Model的tokenizer
        target_tokenizer: Target Model的tokenizer
        feature_predictor: 特征预测模型
        num_draft_tokens: Draft Model生成的token数量
        acceptance_threshold: 接受Draft token的概率阈值
        alpha: 权重系数,用于平衡Draft Model和特征预测

    Returns:
        生成的文本 (字符串)
    """

    # 1. Draft 阶段
    draft_input = draft_tokenizer(prompt, return_tensors="pt").to(draft_model.device)
    draft_output = draft_model.generate(**draft_input, max_new_tokens=num_draft_tokens, output_hidden_states=True, return_dict_in_generate=True)
    draft_tokens = draft_output.sequences[:, draft_input['input_ids'].shape[-1]:]
    draft_text = draft_tokenizer.batch_decode(draft_tokens, skip_special_tokens=True)[0]
    draft_hidden_states = draft_output.hidden_states[-1] # 获取Draft Model的hidden states

    # 2. 特征预测 (困惑度)
    predicted_perplexities = feature_predictor(draft_hidden_states[:, draft_input['input_ids'].shape[-1]:, :]).detach().cpu().numpy()  # 预测每个token的困惑度

    # 3. Verify 阶段
    target_input = target_tokenizer(prompt, return_tensors="pt").to(target_model.device)
    target_logits = target_model(**target_input).logits
    initial_target_token = torch.argmax(target_logits[:, -1, :], dim=-1)  # 获取第一个token的预测

    accepted_tokens = [initial_target_token.item()]  # 存储接受的token, 初始化第一个token
    rejected_indices = []

    # 迭代验证Draft Model生成的token
    for i in range(num_draft_tokens):
        # 构建上下文,包含prompt和之前接受的token
        context_tokens = target_tokenizer(prompt, return_tensors="pt").to(target_model.device)['input_ids']
        context_tokens = torch.cat([context_tokens, torch.tensor([[accepted_tokens[j] for j in range(len(accepted_tokens))]])], dim=1)

        target_logits = target_model(**{"input_ids": context_tokens}).logits  # 注意使用context_tokens作为输入
        target_probs = torch.softmax(target_logits[:, -1, :], dim=-1)
        target_prob_for_draft_token = target_probs[0, draft_tokens[0, i]].item()

        # 4. 调整采样概率 (这里简化为直接使用困惑度作为调整)
        # 假设Target Model期望的困惑度较低,因此困惑度越低,越容易接受
        feature_prob = 1.0 - np.clip(predicted_perplexities[0, i] / 10.0, 0.0, 1.0)  # 将困惑度映射到0-1之间的概率

        # 加权平均采样概率
        adjusted_prob = alpha * target_prob_for_draft_token + (1 - alpha) * feature_prob

        if adjusted_prob >= acceptance_threshold:
            accepted_tokens.append(draft_tokens[0, i].item())
        else:
            rejected_indices.append(i)
            break # 一旦有token被拒绝,停止验证

    # 5. 生成剩余部分 (如果还有未验证的token)
    if rejected_indices:
        # 在被拒绝的token位置,使用Target Model生成新的token
        context_tokens = target_tokenizer(prompt, return_tensors="pt").to(target_model.device)['input_ids']
        context_tokens = torch.cat([context_tokens, torch.tensor([[accepted_tokens[j] for j in range(len(accepted_tokens))]])], dim=1)

        target_output = target_model.generate(**{"input_ids": context_tokens}, max_new_tokens=1)  # 生成一个新token
        new_token = target_output[:, context_tokens.shape[-1]:]
        accepted_tokens.append(new_token[0, 0].item()) # 将新生成的token添加到accepted_tokens

    # 将所有接受的token转换为文本
    generated_text = target_tokenizer.batch_decode([accepted_tokens], skip_special_tokens=True)[0]

    return generated_text

# 初始化模型和tokenizer (这里简化为使用同一个模型)
model_name = "google/gemma-2b"
draft_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
target_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
draft_tokenizer = AutoTokenizer.from_pretrained(model_name)
target_tokenizer = AutoTokenizer.from_pretrained(model_name)

# 初始化特征预测模型
hidden_size = draft_model.config.hidden_size
feature_predictor = FeaturePredictor(hidden_size).to(draft_model.device)

# 训练特征预测模型
train_feature_predictor(target_model, target_tokenizer, feature_predictor)

# 示例用法
prompt = "The capital of France is"
generated_text = eagle_speculative_decode(prompt, draft_model, target_model, draft_tokenizer, target_tokenizer, feature_predictor)
print(f"Generated text (EAGLE): {generated_text}")



投机采样Speculative Sampling)是一种加速大语言模型推理的技术。以下是关于in - house(内部、自有)的投机采样的相关介绍: ### 基本概念 投机采样的核心思想是利用一个较小、速度更快的“草稿模型”(Draft Model)来生成可能的输出候选,然后让较大、更准确的“目标模型”(Target Model)对这些候选进行验证。在in - house场景下,企业或组织可以根据自身需求训练和使用适合的草稿模型和目标模型。 ### 工作流程 1. **草稿模型生成**:草稿模型基于输入生成一系列可能的输出标记(Tokens)。由于草稿模型规模较小,生成速度快。 2. **目标模型验证**:目标模型对草稿模型生成的标记进行验证。它会从草稿模型生成的序列开始处理,检查这些标记是否与自己的预测一致。 3. **接受或拒绝**:如果目标模型验证通过草稿模型生成的部分或全部标记,那么这些标记就被接受并作为最终输出的一部分;如果验证不通过,目标模型会重新生成正确的标记。 ### in - house应用优势 1. **定制化**:企业可以根据自身数据和业务需求,训练专门的草稿模型和目标模型,以更好地适应特定任务。 2. **数据安全**:在in - house环境中处理数据,避免了将敏感数据发送到外部,确保数据安全和隐私。 3. **成本控制**:通过优化模型架构和训练过程,可以降低推理成本,特别是在大规模使用的情况下。 ### 潜在挑战 1. **模型训练**:训练合适的草稿模型和目标模型需要大量的计算资源和专业知识。 2. **性能调优**:需要对草稿模型和目标模型的性能进行调优,以达到最佳的加速效果和输出质量。 ### 示例代码(伪代码) ```python # 输入 input_text = "Some input text" # 草稿模型生成 draft_model = load_draft_model() draft_output = draft_model.generate(input_text) # 目标模型验证 target_model = load_target_model() validated_output = target_model.validate(draft_output, input_text) # 最终输出 final_output = validated_output ```
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值