Source: for a more detailed explanation, see the EAGLE speculative sampling write-up.
Speculative Sampling
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# Initialize the models and tokenizers
draft_model_name = "google/gemma-2b"  # a lightweight draft model
target_model_name = "meta-llama/Llama-2-7b-chat-hf"  # a larger, stronger target model
# Note: for the token-level verification below to be meaningful, the draft and target
# models must share the same tokenizer/vocabulary; the two names above are placeholders
# and do not satisfy this.
draft_model = AutoModelForCausalLM.from_pretrained(draft_model_name, device_map="auto")
target_model = AutoModelForCausalLM.from_pretrained(target_model_name, device_map="auto")
draft_tokenizer = AutoTokenizer.from_pretrained(draft_model_name)
target_tokenizer = AutoTokenizer.from_pretrained(target_model_name)
def speculative_decode(prompt, draft_model, target_model, draft_tokenizer, target_tokenizer, num_draft_tokens=5, acceptance_threshold=0.8):
    """
    A simple speculative sampling implementation.
    Args:
        prompt: input prompt (string)
        draft_model: the draft model
        target_model: the target model
        draft_tokenizer: the draft model's tokenizer
        target_tokenizer: the target model's tokenizer
        num_draft_tokens: number of tokens the draft model generates
        acceptance_threshold: probability threshold for accepting a draft token
    Returns:
        the generated text (string)
    """
    # 1. Draft phase
    draft_input = draft_tokenizer(prompt, return_tensors="pt").to(draft_model.device)
    draft_output = draft_model.generate(**draft_input, max_new_tokens=num_draft_tokens)
    draft_tokens = draft_output[:, draft_input['input_ids'].shape[-1]:]
    draft_text = draft_tokenizer.batch_decode(draft_tokens, skip_special_tokens=True)[0]  # for inspection
    # 2. Verify phase
    target_input = target_tokenizer(prompt, return_tensors="pt").to(target_model.device)
    target_logits = target_model(**target_input).logits
    initial_target_token = torch.argmax(target_logits[:, -1, :], dim=-1)  # prediction for the first new token
    accepted_tokens = [initial_target_token.item()]  # accepted tokens, initialized with that first token
    rejected_indices = []
    # Iteratively verify the tokens produced by the draft model
    for i in range(num_draft_tokens):
        # Build the context: the prompt plus the tokens accepted so far
        context_tokens = target_tokenizer(prompt, return_tensors="pt").to(target_model.device)['input_ids']
        context_tokens = torch.cat(
            [context_tokens, torch.tensor([accepted_tokens], device=context_tokens.device)], dim=1)
        target_logits = target_model(input_ids=context_tokens).logits  # note: context_tokens is the input
        target_probs = torch.softmax(target_logits[:, -1, :], dim=-1)
        target_prob_for_draft_token = target_probs[0, draft_tokens[0, i]].item()
        if target_prob_for_draft_token >= acceptance_threshold:
            accepted_tokens.append(draft_tokens[0, i].item())
        else:
            rejected_indices.append(i)
            break  # stop verifying as soon as a token is rejected
    # 3. Generate the remaining part (if a draft token was rejected)
    if rejected_indices:
        # At the rejected position, let the target model generate a new token
        context_tokens = target_tokenizer(prompt, return_tensors="pt").to(target_model.device)['input_ids']
        context_tokens = torch.cat(
            [context_tokens, torch.tensor([accepted_tokens], device=context_tokens.device)], dim=1)
        target_output = target_model.generate(input_ids=context_tokens, max_new_tokens=1)  # generate one new token
        new_token = target_output[:, context_tokens.shape[-1]:]
        accepted_tokens.append(new_token[0, 0].item())  # append the newly generated token
    # Decode all accepted tokens into text
    generated_text = target_tokenizer.batch_decode([accepted_tokens], skip_special_tokens=True)[0]
    return generated_text
# Example usage
prompt = "The capital of France is"
generated_text = speculative_decode(prompt, draft_model, target_model, draft_tokenizer, target_tokenizer)
print(f"Generated text: {generated_text}")
Module Analysis
Importing the Libraries
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
- torch: the PyTorch deep-learning framework, used for tensor computation and model training.
- AutoModelForCausalLM: the Hugging Face Transformers class that automatically loads a causal language model.
- AutoTokenizer: the Hugging Face Transformers class that automatically loads the matching tokenizer.
Model Initialization
draft_model_name = "google/gemma-2b"
target_model_name = "meta-llama/Llama-2-7b-chat-hf"
draft_model = AutoModelForCausalLM.from_pretrained(draft_model_name, device_map="auto")
target_model = AutoModelForCausalLM.from_pretrained(target_model_name, device_map="auto")
draft_tokenizer = AutoTokenizer.from_pretrained(draft_model_name)
target_tokenizer = AutoTokenizer.from_pretrained(target_model_name)
- draft_model_name and target_model_name: the Hugging Face model names of the lightweight draft model and the stronger target model, respectively.
- AutoModelForCausalLM.from_pretrained: loads the pretrained language model; device_map="auto" places it on the available devices (GPU/CPU) automatically.
- AutoTokenizer.from_pretrained: loads the tokenizer that matches each model.
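Because the verify phase compares draft-token IDs directly against the target model's next-token distribution, the two tokenizers must be interchangeable, as noted in the code comment above. A quick sanity check along these lines (an illustrative addition, not part of the original code) can catch a mismatch early:
# Illustrative sanity check: token-level verification only makes sense when the
# draft and target tokenizers produce identical token IDs.
probe = "The capital of France is Paris."
draft_ids = draft_tokenizer(probe, add_special_tokens=False)["input_ids"]
target_ids = target_tokenizer(probe, add_special_tokens=False)["input_ids"]
if draft_ids != target_ids or draft_tokenizer.vocab_size != target_tokenizer.vocab_size:
    print("Warning: draft and target tokenizers are not interchangeable; "
          "token-level verification will not be meaningful.")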
The Speculative Sampling Function
def speculative_decode(prompt, draft_model, target_model, draft_tokenizer, target_tokenizer, num_draft_tokens=5, acceptance_threshold=0.8):
- Parameters:
  - prompt: the input text prompt.
  - draft_model and target_model: the draft model and the target model.
  - draft_tokenizer and target_tokenizer: their corresponding tokenizers.
  - num_draft_tokens: the number of tokens the draft model generates.
  - acceptance_threshold: the probability threshold for accepting a draft token.
Draft Phase
draft_input = draft_tokenizer(prompt, return_tensors="pt").to(draft_model.device)
draft_output = draft_model.generate(**draft_input, max_new_tokens=num_draft_tokens)
draft_tokens = draft_output[:, draft_input['input_ids'].shape[-1]:]
draft_text = draft_tokenizer.batch_decode(draft_tokens, skip_special_tokens=True)[0]
- draft_tokenizer: converts the input text into input tensors for the draft model.
- draft_model.generate: generates the specified number of tokens.
- draft_tokens: the newly generated tokens, with the input portion sliced off.
- draft_text: the generated tokens decoded back into text.
Verify Phase
target_input = target_tokenizer(prompt, return_tensors="pt").to(target_model.device)
target_logits = target_model(**target_input).logits
initial_target_token = torch.argmax(target_logits[:, -1, :], dim=-1)
accepted_tokens = [initial_target_token.item()]
rejected_indices = []
- target_tokenizer: converts the input text into input tensors for the target model.
- target_model: a forward pass that produces the target model's output logits.
- initial_target_token: the target model's prediction for the first new token.
- accepted_tokens: the list of accepted tokens, initialized with that first token.
- rejected_indices: stores the indices of rejected draft tokens.
Verifying the Draft Tokens
for i in range(num_draft_tokens):
    context_tokens = target_tokenizer(prompt, return_tensors="pt").to(target_model.device)['input_ids']
    context_tokens = torch.cat(
        [context_tokens, torch.tensor([accepted_tokens], device=context_tokens.device)], dim=1)
    target_logits = target_model(input_ids=context_tokens).logits
    target_probs = torch.softmax(target_logits[:, -1, :], dim=-1)
    target_prob_for_draft_token = target_probs[0, draft_tokens[0, i]].item()
    if target_prob_for_draft_token >= acceptance_threshold:
        accepted_tokens.append(draft_tokens[0, i].item())
    else:
        rejected_indices.append(i)
        break
- context_tokens: the context, built from the prompt plus the tokens accepted so far.
- target_model: computes the target model's predictive distribution over the next token given that context.
- target_prob_for_draft_token: the probability the target model assigns to the current draft token.
- If that probability is at or above the threshold, the token is accepted; otherwise it is rejected and verification stops.
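The fixed acceptance_threshold used here is a simplification. Canonical speculative sampling accepts a draft token x with probability min(1, p_target(x) / p_draft(x)) and, on rejection, resamples from the normalized residual max(0, p_target - p_draft), which provably leaves the target model's output distribution unchanged. A minimal sketch of that acceptance step, assuming target_probs and draft_probs are the two models' next-token distributions at the same position:
# Canonical speculative-sampling acceptance step (illustrative sketch).
# target_probs, draft_probs: 1-D tensors over the vocabulary at the current position;
# draft_token_id: the token proposed by the draft model.
def accept_or_resample(draft_token_id, target_probs, draft_probs):
    p = target_probs[draft_token_id]
    q = draft_probs[draft_token_id].clamp_min(1e-10)
    if torch.rand(1).item() < min(1.0, (p / q).item()):
        return draft_token_id, True  # accept the draft token
    # On rejection, resample from the residual distribution max(0, p_target - p_draft)
    residual = torch.clamp(target_probs - draft_probs, min=0.0)
    residual = residual / residual.sum()
    new_token = torch.multinomial(residual, num_samples=1).item()
    return new_token, False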
Generating the Remaining Part
if rejected_indices:
    context_tokens = target_tokenizer(prompt, return_tensors="pt").to(target_model.device)['input_ids']
    context_tokens = torch.cat(
        [context_tokens, torch.tensor([accepted_tokens], device=context_tokens.device)], dim=1)
    target_output = target_model.generate(input_ids=context_tokens, max_new_tokens=1)
    new_token = target_output[:, context_tokens.shape[-1]:]
    accepted_tokens.append(new_token[0, 0].item())
- If a draft token was rejected, the target model generates a replacement token at that position.
- target_model.generate: generates exactly one new token.
- new_token: the generated token, which is appended to the accepted list.
Output
generated_text = target_tokenizer.batch_decode([accepted_tokens], skip_special_tokens=True)[0]
return generated_text
- batch_decode: decodes the accepted tokens back into text.
- The generated text is returned.
Example Usage
prompt = "The capital of France is"
generated_text = speculative_decode(prompt, draft_model, target_model, draft_tokenizer, target_tokenizer)
print(f"Generated text: {generated_text}")
- The prompt text is passed to the speculative sampling function, and the generated text is printed.
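Note that the verification loop above calls the target model once per token, so this sketch illustrates the control flow rather than the speedup; a real implementation scores all draft tokens with a single target-model forward pass. As a quick sanity check of the output, one can compare against plain decoding with the target model alone (an illustrative addition, not part of the original code):
import time

# Baseline: ordinary decoding with the target model only, for comparison
baseline_input = target_tokenizer(prompt, return_tensors="pt").to(target_model.device)
start = time.time()
baseline_output = target_model.generate(**baseline_input, max_new_tokens=6)
baseline_text = target_tokenizer.batch_decode(baseline_output, skip_special_tokens=True)[0]
print(f"Baseline text: {baseline_text}")
print(f"Baseline time: {time.time() - start:.2f}s")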
EAGLE
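The code below is a simplified illustration of the EAGLE idea: besides the target model's probability for each draft token, an auxiliary feature-level signal is used to decide acceptance. Here that feature is reduced to a perplexity predicted from hidden states; the actual EAGLE method instead trains a lightweight autoregressive head that extrapolates the target model's top-layer features, but the adjusted-acceptance logic below conveys the flavor.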
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import AutoModelForCausalLM, AutoTokenizer
import numpy as np
# For simplicity, we use perplexity as the feature and a minimal feature-prediction model
class FeaturePredictor(nn.Module):
    def __init__(self, hidden_size):
        super(FeaturePredictor, self).__init__()
        self.linear = nn.Linear(hidden_size, 1)  # predicts a perplexity value

    def forward(self, hidden_states):
        # hidden_states: (batch_size, sequence_length, hidden_size)
        perplexity = self.linear(hidden_states).squeeze(-1)  # (batch_size, sequence_length)
        return perplexity
def train_feature_predictor(target_model, tokenizer, feature_predictor, num_epochs=3, learning_rate=1e-4):
    """
    Train the feature predictor, using the target model's hidden states and perplexity as training data.
    Args:
        target_model: the target model
        tokenizer: the target model's tokenizer
        feature_predictor: the feature-prediction model
        num_epochs: number of training epochs
        learning_rate: learning rate
    """
    optimizer = optim.Adam(feature_predictor.parameters(), lr=learning_rate)
    criterion = nn.MSELoss()
    # A few sentences to use as training data
    texts = ["The quick brown fox jumps over the lazy dog.",
             "The capital of France is Paris.",
             "Machine learning is a fascinating field.",
             "Coding is fun and challenging."]
    for epoch in range(num_epochs):
        for text in texts:
            inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True).to(target_model.device)
            with torch.no_grad():
                outputs = target_model(**inputs, output_hidden_states=True)
                hidden_states = outputs.hidden_states[-1]  # use the last layer's hidden states
            # Compute the perplexity (used as the ground truth)
            logits = outputs.logits
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = inputs['input_ids'][:, 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
            perplexity = torch.exp(loss)
            # Predict the perplexity from the hidden states (all positions except the last one)
            predicted_perplexity = feature_predictor(hidden_states[:, :-1, :])
            # Compute the loss and update the predictor
            loss = criterion(predicted_perplexity, torch.full_like(predicted_perplexity, perplexity.item()))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        print(f"Epoch {epoch+1}, Loss: {loss.item()}")
def eagle_speculative_decode(prompt, draft_model, target_model, draft_tokenizer, target_tokenizer, feature_predictor, num_draft_tokens=5, acceptance_threshold=0.8, alpha=0.5):
    """
    EAGLE-style speculative sampling implementation.
    Args:
        prompt: input prompt (string)
        draft_model: the draft model
        target_model: the target model
        draft_tokenizer: the draft model's tokenizer
        target_tokenizer: the target model's tokenizer
        feature_predictor: the feature-prediction model
        num_draft_tokens: number of tokens the draft model generates
        acceptance_threshold: probability threshold for accepting a draft token
        alpha: weighting coefficient balancing the target probability and the feature prediction
    Returns:
        the generated text (string)
    """
    # 1. Draft phase
    draft_input = draft_tokenizer(prompt, return_tensors="pt").to(draft_model.device)
    draft_output = draft_model.generate(**draft_input, max_new_tokens=num_draft_tokens,
                                        output_hidden_states=True, return_dict_in_generate=True)
    draft_tokens = draft_output.sequences[:, draft_input['input_ids'].shape[-1]:]
    draft_text = draft_tokenizer.batch_decode(draft_tokens, skip_special_tokens=True)[0]  # for inspection
    # Collect the draft model's hidden states: generate() returns one tuple of per-layer states
    # per generation step, so take the last layer's state at the last position of each step
    step_states = [step[-1][:, -1, :] for step in draft_output.hidden_states]
    draft_hidden_states = torch.stack(step_states, dim=1)  # (batch, num_draft_tokens, hidden_size)
    # 2. Feature prediction (perplexity)
    predicted_perplexities = feature_predictor(draft_hidden_states).detach().cpu().numpy()  # one value per draft token
    # 3. Verify phase
    target_input = target_tokenizer(prompt, return_tensors="pt").to(target_model.device)
    target_logits = target_model(**target_input).logits
    initial_target_token = torch.argmax(target_logits[:, -1, :], dim=-1)  # prediction for the first new token
    accepted_tokens = [initial_target_token.item()]  # accepted tokens, initialized with that first token
    rejected_indices = []
    # Iteratively verify the tokens produced by the draft model
    for i in range(num_draft_tokens):
        # Build the context: the prompt plus the tokens accepted so far
        context_tokens = target_tokenizer(prompt, return_tensors="pt").to(target_model.device)['input_ids']
        context_tokens = torch.cat(
            [context_tokens, torch.tensor([accepted_tokens], device=context_tokens.device)], dim=1)
        target_logits = target_model(input_ids=context_tokens).logits  # note: context_tokens is the input
        target_probs = torch.softmax(target_logits[:, -1, :], dim=-1)
        target_prob_for_draft_token = target_probs[0, draft_tokens[0, i]].item()
        # 4. Adjust the acceptance probability (simplified: use the predicted perplexity directly)
        # The target model prefers low perplexity, so the lower the perplexity, the easier to accept
        feature_prob = 1.0 - np.clip(predicted_perplexities[0, i] / 10.0, 0.0, 1.0)  # map perplexity to a 0-1 score
        # Weighted combination of the two signals
        adjusted_prob = alpha * target_prob_for_draft_token + (1 - alpha) * feature_prob
        if adjusted_prob >= acceptance_threshold:
            accepted_tokens.append(draft_tokens[0, i].item())
        else:
            rejected_indices.append(i)
            break  # stop verifying as soon as a token is rejected
    # 5. Generate the remaining part (if a draft token was rejected)
    if rejected_indices:
        # At the rejected position, let the target model generate a new token
        context_tokens = target_tokenizer(prompt, return_tensors="pt").to(target_model.device)['input_ids']
        context_tokens = torch.cat(
            [context_tokens, torch.tensor([accepted_tokens], device=context_tokens.device)], dim=1)
        target_output = target_model.generate(input_ids=context_tokens, max_new_tokens=1)  # generate one new token
        new_token = target_output[:, context_tokens.shape[-1]:]
        accepted_tokens.append(new_token[0, 0].item())  # append the newly generated token
    # Decode all accepted tokens into text
    generated_text = target_tokenizer.batch_decode([accepted_tokens], skip_special_tokens=True)[0]
    return generated_text
# Initialize the models and tokenizers (simplified here: the same model plays both roles)
model_name = "google/gemma-2b"
draft_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
target_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
draft_tokenizer = AutoTokenizer.from_pretrained(model_name)
target_tokenizer = AutoTokenizer.from_pretrained(model_name)
# Initialize the feature-prediction model
hidden_size = draft_model.config.hidden_size
feature_predictor = FeaturePredictor(hidden_size).to(draft_model.device)
# Train the feature predictor
train_feature_predictor(target_model, target_tokenizer, feature_predictor)
# Example usage
prompt = "The capital of France is"
generated_text = eagle_speculative_decode(prompt, draft_model, target_model, draft_tokenizer, target_tokenizer, feature_predictor)
print(f"Generated text (EAGLE): {generated_text}")
