query改写微调(T5 + DPO)

1. 环境准备

pip install openai transformers sentence-transformers scikit-learn torch tensorboard datasets nltk

2. 数据生成

2.1 使用 ChatGPT API 生成改写数据对

import openai
from typing import Tuple

def generate_rewrite_with_chatgpt(prompt: str, model_name: str = 'gpt-4', temperature: float = 0.7) -> str:
    openai.api_key = 'your_openai_api_key'
    response = openai.Completion.create(
        model=model_name,
        prompt=prompt,
        temperature=temperature,
        max_tokens=60
    )
    return response.choices[0].text.strip()

def generate_positive_negative_pairs_with_chatgpt(query: str) -> Tuple[str, str]:
    positive_rewrite = generate_rewrite_with_chatgpt(f"Rewrite the following query: {query}")
    negative_rewrite = generate_rewrite_with_chatgpt(f"Generate a random sentence unrelated to: {query}")
    return positive_rewrite, negative_rewrite


2.2 数据增强

from nltk.corpus import wordnet
import random
from typing import List

def augment_data(original_texts: List[str], num_augmentations: int) -> List[str]:
    """
    使用同义词替换进行数据增强。
    
    :param original_texts: 原始文本列表,其中每个文本是一个字符串
    :param num_augmentations: 每个文本生成的增强样本数,控制增强的数量
    :return: 增强后的文本列表,包含了使用同义词替换生成的变体
    """
    augmented_texts = []
    
    # 遍历原始文本列表
    for text in original_texts:
        words = text.split()  # 将文本分割成单词列表
        
        # 生成指定数量的增强样本
        for _ in range(num_augmentations):
            augmented_words = []
            
            # 遍历文本中的每个单词
            for word in words:
                # 获取单词的同义词集合
                synonyms = wordnet.synsets(word)
                
                # 如果找到同义词,则随机选择一个同义词作为替换
                if synonyms:
                    # 选择同义词的第一个词义的第一个同义词
                    synonym = random.choice(synonyms).lemmas()[0].name()
                    augmented_words.append(synonym)
                else:
                    # 如果没有找到同义词,则保留原始单词
                    augmented_words.append(word)
            
            # 将增强的单词列表合并成一个新的文本
            new_text = ' '.join(augmented_words)
            augmented_texts.append(new_text)  # 将新生成的文本添加到结果列表中
    
    return augmented_texts


2.3 数据验证

数据验证功能通过语义相似度、Jaccard 相似度、BM25 分数和长度差异四种方法,筛选符合质量标准的改写文本。

import numpy as np
from sentence_transformers import SentenceTransformer, util
from rank_bm25 import BM25Okapi
from typing import List, Dict

# 1. 语义相似度验证
def validate_rewrites(original_query: str, generated_rewrites: List[str], similarity_threshold: float = 0.7) -> Dict[str, float]:
    """
    使用 SentenceTransformer 模型计算生成文本的语义相似度,并筛选出符合阈值的改写文本。
    
    :param original_query: 原始查询文本
    :param generated_rewrites: 生成的改写文本列表
    :param similarity_threshold: 语义相似度阈值,默认值为 0.7
    :return: 包含符合条件的改写文本及其语义相似度的字典
    """
    model = SentenceTransformer('all-MiniLM-L6-v2')
    original_embedding = model.encode(original_query, convert_to_tensor=True)
    results = {}

    # 使用 numpy 进行批量计算
    rewrite_embeddings = model.encode(generated_rewrites, convert_to_tensor=True)
    similarities = util.pytorch_cos_sim(original_embedding, rewrite_embeddings).cpu().numpy().flatten()

    for rewrite, similarity in zip(generated_rewrites, similarities):
        if similarity >= similarity_threshold:
            results[rewrite] = similarity

    return results

# 2. Jaccard 相似度验证
def jaccard_similarity(query: str, rewrite: str) -> float:
    """
    计算两个文本的 Jaccard 相似度。
    
    :param query: 原始查询文本
    :param rewrite: 改写文本
    :return: Jaccard 相似度
    """
    query_set = set(query.split())
    rewrite_set = set(rewrite.split())
    intersection = len(query_set & rewrite_set)
    union = len(query_set | rewrite_set)
    return intersection / union

def validate_rewrites_jaccard(original_query: str, generated_rewrites: List[str], similarity_threshold: float = 0.5) -> Dict[str, float]:
    """
    使用 Jaccard 相似度验证生成的改写文本,并筛选出符合阈值的改写文本。
    
    :param original_query: 原始查询文本
    :param generated_rewrites: 生成的改写文本列表
    :param similarity_threshold: Jaccard 相似度阈值,默认值为 0.5
    :return: 包含符合条件的改写文本及其 Jaccard 相似度的字典
    """
    results = {}
    for rewrite in generated_rewrites:
        similarity = jaccard_similarity(original_query, rewrite)
        if similarity >= similarity_threshold:
            results[rewrite] = similarity
    return results

# 3. BM25 验证
def calculate_bm25(query: str, documents: List[str]) -> List[float]:
    """
    使用 BM25 模型计算查询与多个文档之间的相关性分数。
    
    :param query: 查询文本
    :param documents: 文档文本列表
    :return: 查询与文档之间的 BM25 分数列表
    """
    tokenized_query = query.split()
    tokenized_docs = [doc.split() for doc in documents]
    bm25 = BM25Okapi(tokenized_docs)
    scores = bm25.get_scores(tokenized_query)
    return scores

def validate_rewrites_bm25(original_query: str, generated_rewrites: List[str], score_threshold: float = 1.0) -> Dict[str, float]:
    """
    使用 BM25 验证生成的改写文本的有效性,并筛选出符合阈值的改写文本。
    
    :param original_query: 原始查询文本
    :param generated_rewrites: 生成的改写文本列表
    :param score_threshold: BM25 分数阈值,默认值为 1.0
    :return: 包含符合条件的改写文本及其 BM25 分数的字典
    """
    all_rewrites = [original_query] + generated_rewrites
    scores = calculate_bm25(original_query, all_rewrites)
    results = {}
    for rewrite, score in zip(generated_rewrites, scores[1:]):
        if score >= score_threshold:
            results[rewrite] = score
    return results

# 4. 长度差异验证
def validate_length_difference(original_query: str, generated_rewrites: List[str], length_threshold: float = 0.3) -> Dict[str, float]:
    """
    验证改写文本与原始文本的长度差异,并筛选出符合阈值的改写文本。
    
    :param original_query: 原始查询文本
    :param generated_rewrites: 生成的改写文本列表
    :param length_threshold: 长度差异阈值,默认值为 0.3
    :return: 包含符合条件的改写文本及其长度差异的字典
    """
    original_length = len(original_query.split())
    results = {}

    # 将生成的改写文本转换为 numpy 数组
    rewrite_lengths = np.array([len(rewrite.split()) for rewrite in generated_rewrites])

    # 计算长度差异
    length_diffs = np.abs(original_length - rewrite_lengths)

    # 计算归一化的长度差异
    normalized_diffs = length_diffs / np.maximum(original_length, rewrite_lengths)

    # 将符合条件的改写文本添加到结果中
    for rewrite, normalized_diff in zip(generated_rewrites, normalized_diffs):
        if normalized_diff <= length_threshold:
            results[rewrite] = normalized_diff

    return results

# 5. 过滤和分类
def filter_valid_rewrites(rewrites: List[str], original_query: str, similarity_threshold: float = 0.7, length_threshold: float = 0.3, bm25_threshold: float = 1.0) -> List[str]:
    """
    综合应用所有验证方法,筛选符合条件的改写文本。
    
    :param rewrites: 生成的改写文本列表
    :param original_query: 原始查询文本
    :param similarity_threshold: 语义相似度阈值
    :param length_threshold: 长度差异阈值
    :param bm25_threshold: BM25 分数阈值
    :return: 符合所有条件的改写文本列表
    """
    semantic_similarities = validate_rewrites(original_query, rewrites, similarity_threshold)
    length_differences = validate_length_difference(original_query, rewrites, length_threshold)
    bm25_scores = validate_rewrites_bm25(original_query, rewrites, bm25_threshold)
    
    valid_rewrites = []
    for rewrite in rewrites:
        if rewrite in semantic_similarities and rewrite in length_differences and rewrite in bm25_scores:
            valid_rewrites.append(rewrite)
    
    return valid_rewrites



2.4 数据集示例

# law_legal_data.csv
query,rewrite
"What is the penalty for breach of contract?", "What are the consequences for breaking a contract?"
"How can I file for a divorce?", "What is the process for initiating a divorce?"

3. 数据集准备

law_legal_data.csv 中的数据需要进一步处理才能用于 T5 模型的训练。T5 模型通常需要特定的输入格式和数据预处理步骤来进行有效的训练。

3.1 准备训练数据集

将原始的 CSV 文件转换为 T5 模型可以接受的格式。T5 模型的训练通常需要将数据转换为特定的文本输入格式,通常是 “source_text -> target_text” 的形式。在这里,source_text 是我们希望模型转换的查询,而 target_text 是改写后的文本。

import pandas as pd
from datasets import Dataset, DatasetDict

def load_and_process_data(file_path: str) -> DatasetDict:
    """
    从 CSV 文件中加载和处理数据,并将其转换为适合 T5 训练的数据集格式。
    
    :param file_path: 包含查询和改写的 CSV 文件路径
    :return: 转换后的 DatasetDict 对象,包含训练和验证数据集
    """
    # 读取 CSV 文件
    df = pd.read_csv(file_path)
    
    # 处理数据,将其转换为适合 T5 训练的格式
    df['input_text'] = "rewrite: " + df['query']  # 添加前缀以指示 T5 模型执行重写任务
    df['target_text'] = df['rewrite']
    
    # 将 DataFrame 转换为 Hugging Face 的 Dataset 对象
    dataset = Dataset.from_pandas(df[['input_text', 'target_text']])
    
    # 将数据集分为训练和验证集
    train_test_split = dataset.train_test_split(test_size=0.1)
    
    return DatasetDict({
        'train': train_test_split['train'],
        'validation': train_test_split['test']
    })

# 加载数据并处理
dataset_dict = load_and_process_data('law_legal_data.csv')

训练数据集 (dataset_dict[‘train’]) 将包含以下格式的样本:

  • input_text: “rewrite: What is the penalty for breach of contract?”
  • target_text: “What are the consequences for breaking a contract?”

4. T5模型训练

4.1 训练T5模型

from transformers import T5Tokenizer, T5ForConditionalGeneration, Trainer, TrainingArguments
#from datasets import Dataset, DatasetDict
from torch.utils.tensorboard import SummaryWriter
from transformers import TrainerCallback
import pandas as pd

# 加载数据并处理
#dataset_dict = load_and_process_data('law_legal_data.csv')

# 加载 T5 模型和分词器
model_name = "t5-large"
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)

# 数据预处理函数
def preprocess_function(examples):
    """
    预处理函数,将原始文本转换为模型所需的输入格式。
    
    :param examples: 包含查询和改写文本的字典
    :return: 包含模型输入和标签的字典
    """
    inputs = examples['input_text']
    targets = examples['target_text']
    
    # 将输入和标签文本转换为模型可接受的格式
    model_inputs = tokenizer(inputs, max_length=128, truncation=True, padding="max_length")
    labels = tokenizer(targets, max_length=128, truncation=True, padding="max_length")
    
    # 将标签添加到模型输入中
    model_inputs['labels'] = labels['input_ids']
    return model_inputs

# 处理数据集
tokenized_datasets = dataset_dict.map(preprocess_function, batched=True)

# 设置训练参数
training_args = TrainingArguments(
    per_device_train_batch_size=4,                # 每个设备上的训练批次大小
    per_device_eval_batch_size=4,                 # 每个设备上的评估批次大小
    num_train_epochs=3,                           # 训练轮数
    logging_dir='./logs',                         # 日志目录
    logging_steps=10,                             # 日志记录步数
    save_steps=1000,                              # 模型保存步数
    output_dir='./results',                       # 模型保存目录
    evaluation_strategy="steps",                  # 评估策略,基于步数
    save_total_limit=3,                           # 保存的模型检查点数量限制
    learning_rate=2e-5,                           # 学习率
    warmup_steps=500,                             # 学习率预热步数
    weight_decay=0.01,                            # 权重衰减
    load_best_model_at_end=True,                  # 在训练结束时加载最好的模型
    metric_for_best_model="eval_loss",            # 选择评估指标来确定最好的模型
    greater_is_better=False                       # 更低的评估损失更好
)

# TensorBoard 回调类
class TensorBoardCallback(TrainerCallback):
    def __init__(self, log_dir):
        """
        初始化 TensorBoard 回调。
        
        :param log_dir: TensorBoard 日志目录
        """
        self.writer = SummaryWriter(log_dir)

    def on_log(self, args, state, control, logs=None, **kwargs):
        """
        在每个日志记录步骤中更新 TensorBoard。
        
        :param args: 训练参数
        :param state: 训练状态
        :param control: 控制参数
        :param logs: 日志信息
        """
        if logs is not None:
            for key, value in logs.items():
                self.writer.add_scalar(f"train/{key}", value, state.global_step)

# 创建 TensorBoard 回调实例
tensorboard_callback = TensorBoardCallback(log_dir="./tensorboard_logs")

# 创建 Trainer 实例
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets['train'],
    eval_dataset=tokenized_datasets['validation'],
    callbacks=[tensorboard_callback]  # 添加 TensorBoard 回调
)

# 开始训练
trainer.train()

# 保存模型和分词器
model.save_pretrained("./t5-law-model")
tokenizer.save_pretrained("./t5-law-model")


5. DPO训练

5.0.0 csv 数据样例

sentence_a,sentence_b,preference
"What is the penalty for breach of contract?", "What are the consequences for breaking a contract?", 1
"How can I file for a divorce?", "What is the process for initiating a divorce?", 1
"How can I file for a divorce?", "What is the penalty for breach of contract?", 0

5.0.1 数据加载和处理:

import pandas as pd
from datasets import Dataset, DatasetDict

def load_and_process_dpo_data(file_path: str) -> DatasetDict:
    """
    从 CSV 文件中加载和处理 DPO 数据,并将其转换为适合 DPO 训练的数据集格式。
    
    :param file_path: 包含 A 句、B 句和偏好的 CSV 文件路径
    :return: 转换后的 DatasetDict 对象,包含训练和验证数据集
    """
    # 读取 CSV 文件
    df = pd.read_csv(file_path)
    
    # 处理数据,将其转换为适合 DPO 训练的格式
    df['input_text_a'] = df['sentence_a']
    df['input_text_b'] = df['sentence_b']
    df['labels'] = df['preference']
    
    # 将 DataFrame 转换为 Hugging Face 的 Dataset 对象
    dataset = Dataset.from_pandas(df[['input_text_a', 'input_text_b', 'labels']])
    
    # 将数据集分为训练和验证集
    train_test_split = dataset.train_test_split(test_size=0.1)
    
    return DatasetDict({
        'train': train_test_split['train'],
        'validation': train_test_split['test']
    })

# 加载数据并处理
dataset_dict = load_and_process_dpo_data('dpo_data.csv')

5.0.2 数据预处理:

from transformers import T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained('t5-large')

def preprocess_dpo_function(examples):
    """
    预处理函数,将 A 句和 B 句转换为模型所需的输入格式,并创建标签。
    
    :param examples: 包含 A 句、B 句和偏好的字典
    :return: 包含模型输入和标签的字典
    """
    inputs_a = examples['input_text_a']
    inputs_b = examples['input_text_b']
    labels = examples['labels']
    
    # 编码 A 句和 B 句
    model_inputs_a = tokenizer(inputs_a, max_length=128, truncation=True, padding="max_length")
    model_inputs_b = tokenizer(inputs_b, max_length=128, truncation=True, padding="max_length")
    
    # 创建模型输入
    model_inputs = {
        'input_ids': model_inputs_a['input_ids'],
        'attention_mask': model_inputs_a['attention_mask'],
        'labels': model_inputs_b['input_ids'],
        'decoder_attention_mask': model_inputs_b['attention_mask']
    }
    
    # 将标签转换为浮点数
    model_inputs['labels'] = [float(label) for label in labels]
    
    return model_inputs

# 处理数据集
tokenized_datasets = dataset_dict.map(preprocess_dpo_function, batched=True)

5.1 使用DPO进行模型微调

from transformers import DPOConfig, DPOTrainer, TrainingArguments
from torch.utils.tensorboard import SummaryWriter
from transformers import TrainerCallback

# DPO配置
dpo_config = DPOConfig(
    model_name_or_path="./t5-law-model",  # 指定预训练模型的路径
    # 其他 DPO 配置参数可以在这里设置
)

# TensorBoard 回调类
class TensorBoardCallback(TrainerCallback):
    def __init__(self, log_dir):
        """
        初始化 TensorBoard 回调,用于记录训练过程中的日志信息。
        
        :param log_dir: TensorBoard 日志目录
        """
        self.writer = SummaryWriter(log_dir)  # 创建 TensorBoard SummaryWriter 实例

    def on_log(self, args, state, control, logs=None, **kwargs):
        """
        在每个日志记录步骤中更新 TensorBoard。
        
        :param args: 训练参数
        :param state: 训练状态
        :param control: 控制参数
        :param logs: 日志信息
        """
        if logs is not None:
            for key, value in logs.items():
                self.writer.add_scalar(f"train/{key}", value, state.global_step)
                # 记录训练过程中的指标到 TensorBoard

# 创建 TensorBoard 回调实例
tensorboard_callback = TensorBoardCallback(log_dir="./tensorboard_logs")

# 设置训练参数
training_args = TrainingArguments(
    per_device_train_batch_size=4,  # 每个设备上的训练批次大小
    per_device_eval_batch_size=4,   # 每个设备上的评估批次大小
    num_train_epochs=3,             # 训练的总轮次
    logging_dir='./logs',           # 日志保存目录
    logging_steps=10,               # 每多少步记录一次日志
    save_steps=1000,                # 每多少步保存一次模型
    output_dir='./results',         # 模型保存目录
    evaluation_strategy="steps",    # 评估策略,按步评估
    save_total_limit=3,             # 最多保存的模型数
    learning_rate=2e-5,             # 学习率
    warmup_steps=500,               # 预热步数
    weight_decay=0.01,              # 权重衰减
    load_best_model_at_end=True,    # 训练结束时加载最佳模型
    metric_for_best_model="eval_loss", # 用于选择最佳模型的指标
    greater_is_better=False         # 指标是否越大越好
)

# 创建 DPO 训练器
dpo_trainer = DPOTrainer(
    model=model,                     # 训练的模型
    args=training_args,              # 训练参数
    train_dataset=tokenized_datasets['train'],  # 训练数据集
    eval_dataset=tokenized_datasets['validation'],  # 验证数据集
    dpo_config=dpo_config,           # DPO配置
    callbacks=[tensorboard_callback] # 回调函数
)

# 开始 DPO 训练
dpo_trainer.train()

# 保存微调后的模型
model.save_pretrained("./t5-law-model-dpo")  # 保存模型到指定路径
tokenizer.save_pretrained("./t5-law-model-dpo")  # 保存分词器到指定路径      

  • 学习率调度:可以在 TrainingArguments 中配置 learning_rate 和 lr_scheduler_type。
  • 数据增强:增加更多的数据增强技术,如随机删除、插入等。
  • 早停策略:使用早停策略防止过拟合,可以设置 load_best_model_at_end=True。
  • 模型检查点:定期保存模型检查点,确保在训练中断时可以恢复。

6. TensorBoard监控

6.1 启动TensorBoard

tensorboard --logdir=./tensorboard_logs

7 使用DPO训练后的模型进行推理

from transformers import T5Tokenizer, T5ForConditionalGeneration

# 加载微调后的模型和分词器
model_path = "./t5-law-model-dpo"
tokenizer = T5Tokenizer.from_pretrained(model_path)  # 加载分词器
model = T5ForConditionalGeneration.from_pretrained(model_path)  # 加载模型

def preprocess_input(query: str, tokenizer) -> dict:
    """
    处理输入数据,以适应模型的输入要求。
    
    :param query: 要进行改写的查询文本
    :param tokenizer: 模型的分词器
    :return: 模型输入字典
    """
    input_text = f"rewrite: {query}"  # 添加任务前缀
    inputs = tokenizer(input_text, return_tensors="pt", max_length=128, truncation=True)
    return inputs

def generate_rewrite(query: str, model, tokenizer) -> str:
    """
    使用模型生成改写文本。
    
    :param query: 要进行改写的查询文本
    :param model: 微调后的模型
    :param tokenizer: 模型的分词器
    :return: 生成的改写文本
    """
    inputs = preprocess_input(query, tokenizer)  # 准备输入
    outputs = model.generate(
        input_ids=inputs['input_ids'], 
        attention_mask=inputs['attention_mask'],
        max_length=128,                # 生成文本的最大长度
        num_beams=5,                   # 使用束搜索以提高生成质量
        early_stopping=True            # 提前停止以提高效率
    )
    rewrite = tokenizer.decode(outputs[0], skip_special_tokens=True)  # 解码生成的文本
    return rewrite

# 示例查询
query = "What is the penalty for breach of contract?"

# 生成改写
rewrite = generate_rewrite(query, model, tokenizer)
print(f"Original Query: {query}")
print(f"Generated Rewrite: {rewrite}")

7.1 vllm 部署推理

pip install vllm

单次推理

import torch
from transformers import T5Tokenizer
from vllm import LLM, CompletionConfig

# 加载模型和分词器
model_path = "./t5-law-model-dpo"
tokenizer = T5Tokenizer.from_pretrained(model_path)  # 加载分词器

# 初始化vllm模型
llm = LLM.from_pretrained(model_path)

def preprocess_input(query: str) -> str:
    """
    处理输入数据以适应模型的输入要求。
    
    :param query: 要进行改写的查询文本
    :return: 处理后的输入文本
    """
    input_text = f"rewrite: {query}"  # 添加任务前缀
    return input_text

def generate_rewrite_vllm(query: str, llm) -> str:
    """
    使用vllm进行加速推理。
    
    :param query: 要进行改写的查询文本
    :param llm: vllm模型
    :return: 生成的改写文本
    """
    input_text = preprocess_input(query)  # 准备输入文本
    completion_config = CompletionConfig(
        max_length=128,           # 生成文本的最大长度
        num_beams=5,              # 使用束搜索
        early_stopping=True,      # 提前停止
        temperature=0.7,          # 控制生成的随机性
        top_k=50,                 # 限制生成的词汇数目
        top_p=0.95                # 控制生成的词汇概率分布
    )
    # 生成文本
    response = llm.complete(prompt=input_text, config=completion_config)
    rewrite = response['text']
    return rewrite

# 示例查询
query = "What is the penalty for breach of contract?"

# 生成改写
rewrite = generate_rewrite_vllm(query, llm)
print(f"Original Query: {query}")
print(f"Generated Rewrite: {rewrite}")

批量推理

import torch
from transformers import T5Tokenizer
from vllm import LLM, CompletionConfig

# 加载模型和分词器
model_path = "./t5-law-model-dpo"
tokenizer = T5Tokenizer.from_pretrained(model_path)  # 加载分词器

# 初始化vllm模型
llm = LLM.from_pretrained(model_path)

def preprocess_batch_inputs(queries: List[str]) -> List[str]:
    """
    处理批量输入数据以适应模型的输入要求。
    
    :param queries: 要进行改写的查询文本列表
    :return: 处理后的输入文本列表
    """
    return [f"rewrite: {query}" for query in queries]  # 添加任务前缀

def generate_rewrites_batch(queries: List[str], llm) -> List[str]:
    """
    使用vllm进行批量推理。
    
    :param queries: 要进行改写的查询文本列表
    :param llm: vllm模型
    :return: 生成的改写文本列表
    """
    input_texts = preprocess_batch_inputs(queries)  # 准备输入文本
    rewrites = []
    completion_config = CompletionConfig(
        max_length=128,           # 生成文本的最大长度
        num_beams=5,              # 使用束搜索
        early_stopping=True,      # 提前停止
        temperature=0.7,          # 控制生成的随机性
        top_k=50,                 # 限制生成的词汇数目
        top_p=0.95                # 控制生成的词汇概率分布
    )
    for input_text in input_texts:
        # 生成文本
        response = llm.complete(prompt=input_text, config=completion_config)
        rewrite = response['text']
        rewrites.append(rewrite)
    return rewrites

# 示例查询列表
queries = [
    "What is the penalty for breach of contract?",
    "How can I file for a divorce?"
]

# 批量生成改写
rewrites = generate_rewrites_batch(queries, llm)
for query, rewrite in zip(queries, rewrites):
    print(f"Original Query: {query}")
    print(f"Generated Rewrite: {rewrite}")

其他优化建议

1,并行计算:
  • 使用 torch.nn.DataParallel 或 torch.nn.parallel.DistributedDataParallel 进行模型的多GPU推理,以提高计算速度。
2,模型量化:
  • 使用量化技术减少模型的存储需求和计算复杂度,从而加速推理。可以使用 torch.quantization 进行模型量化。
3,模型剪枝:
  • 对模型进行剪枝以减少计算量和提高推理速度。剪枝可以通过 transformers 库中的一些工具或自定义实现来完成。
4,优化生成参数:
  • 通过调整生成参数(如 temperature、top_k、top_p)来控制生成文本的质量和多样性。通常情况下,较低的 temperature 和较小的 top_k 值会生成更为确定性的文本,而较高的 temperature 和较大的 top_k 值则会生成更多样化的文本。
  • 9
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值