3.2 Prompt-Tuning原理与实战

目录

1 Prompt-Tuning原理:

1.1 hard prompt

1.2 soft prompt

2 环境配置:

3 Prompt Tuning 代码实战演练

3.1 导包

3.2 加载数据集

3.3 数据集处理

3.4 创建模型

3.4.1 创建配置文件

3.4.1.1 导包:

3.4.1.2 软提示:

3.4.1.3 硬提示:

3.4.2 构造模型

3.5 配置训练参数

3.6 创建训练器

3.7 模型训练

3.8 模型推理


1 Prompt-Tuning原理:

1.1 hard prompt

硬提示是通过在输入中添加预定义的自然语言提示,引导模型完成任务。假设我们有一个情感分析任务,输入是一段文本,目标是判断文本的情感是“正面”还是“负面”。

  • 输入:"The sentiment of the following sentence is [MASK]: 'The movie was fantastic!'"
  • 模型预测结果(例如):“positive”

优点:提示可以是人类可读的,且与任务相关。

缺点:设计提示依赖于对任务的理解,效果可能因提示的具体表达方式而异。

1.2 soft prompt

软提示则是将提示部分转换成可学习的向量,模型通过训练这些向量来理解任务。软提示通常不再是自然语言的形式,而是嵌入到模型输入的一个连续的、可训练的向量。仍然以情感分析为例。

  • 输入:<soft-prompt> + "The movie was fantastic!"
  • 其中,<soft-prompt> 是一组可学习的嵌入向量,模型会根据这部分向量来调整对情感分析任务的理解。
  • 模型通过训练得到的提示向量,预测结果(例如):“positive”

优点:模型可以通过训练学习到最适合任务的提示,不受自然语言表达方式的限制。

缺点:提示是抽象的,不再是人类可读的,需要通过训练来学习。

2 环境配置:

可以在PEFT methods (hf.space)中查看NLP各种任务支持的模型和对应的微调方法

3 Prompt Tuning 代码实战演练

这块大体上我们沿用之前的代码,需要注意的是,Prompt Tuning微调是要在训练数据前加入一小段Prompt,但是这并不意味着我们要在数据处理时加入,只需要在创建模型时予以修改即可。

3.1 导包

from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForSeq2Seq, TrainingArguments, Trainer

3.2 加载数据集

ds = Dataset.load_from_disk("../data/alpaca_data_zh/")
ds

3.3 数据集处理

tokenizer = AutoTokenizer.from_pretrained("Langboat/bloom-1b4-zh")
tokenizer
def process_func(example):
    MAX_LENGTH = 256
    input_ids, attention_mask, labels = [], [], []
    instruction = tokenizer("\n".join(["Human: " + example["instruction"], example["input"]]).strip() + "\n\nAssistant: ")
    response = tokenizer(example["output"] + tokenizer.eos_token)
    input_ids = instruction["input_ids"] + response["input_ids"]
    attention_mask = instruction["attention_mask"] + response["attention_mask"]
    labels = [-100] * len(instruction["input_ids"]) + response["input_ids"]
    if len(input_ids) > MAX_LENGTH:
        input_ids = input_ids[:MAX_LENGTH]
        attention_mask = attention_mask[:MAX_LENGTH]
        labels = labels[:MAX_LENGTH]
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels
    }
tokenized_ds = ds.map(process_func, remove_columns=ds.column_names)
tokenized_ds

3.4 创建模型

model = AutoModelForCausalLM.from_pretrained("Langboat/bloom-1b4-zh", low_cpu_mem_usage=True)

接下来需要分为两步走:

3.4.1 创建配置文件

3.4.1.1 导包:
from peft import PromptTuningConfig, get_peft_model, TaskType, PromptTuningInit
  1. PromptTuningConfig 是用于配置 Prompt-Tuning 的类。它定义了 Prompt-Tuning 的具体设置和超参数,控制提示(prompt)部分的长度、初始化方式等。
  2. get_peft_model 是 PEFT 库中的一个函数,它根据传入的基础预训练模型和特定的配置(如 Prompt-Tuning 配置)返回一个经过 PEFT 微调后的模型。
  3. TaskType 是一个枚举类型,用于指定不同的任务类型。
  4. PromptTuningInit 是一个初始化策略,用于控制 Prompt-Tuning 中提示向量的初始化方式。例如是采用hard还是soft。
3.4.1.2 软提示:
# Soft Prompt
config = PromptTuningConfig(task_type=TaskType.CAUSAL_LM, num_virtual_tokens=10)
config

task_type=TaskType.CAUSAL_LM表示任务类型是自回归语言模型; 

num_virtual_tokens=10用于指定prompt的长度。

PromptTuningConfig(peft_type=<PeftType.PROMPT_TUNING: 'PROMPT_TUNING'>, auto_mapping=None, base_model_name_or_path=None, revision=None, task_type=<TaskType.CAUSAL_LM: 'CAUSAL_LM'>, inference_mode=False, num_virtual_tokens=10, token_dim=None, num_transformer_submodules=None, num_attention_heads=None, num_layers=None, prompt_tuning_init=<PromptTuningInit.RANDOM: 'RANDOM'>, prompt_tuning_init_text=None, tokenizer_name_or_path=None) 

prompt_tuning_init=<PromptTuningInit.RANDOM: 'RANDOM'表示初始化方式为软提示 

3.4.1.3 硬提示:
# Hard Prompt
config = PromptTuningConfig(task_type=TaskType.CAUSAL_LM,
                            prompt_tuning_init=PromptTuningInit.TEXT,
                            prompt_tuning_init_text="下面是一段人与机器人的对话。",
                            num_virtual_tokens=len(tokenizer("下面是一段人与机器人的对话。")["input_ids"]),
                            tokenizer_name_or_path="Langboat/bloom-1b4-zh")
config

prompt_tuning_init=PromptTuningInit.TEXT表示采用硬提示;

prompt_tuning_init_text指定具体的prompt内容;

num_virtual_tokens用于指定长度,prompt长度超过这个会被截断,小于这个会循环增长,这里我们保持与prompt的长度一致即可;

tokenizer_name_or_path用于指定tokenizer。 

3.4.2 构造模型

model = get_peft_model(model, config)

model
PeftModelForCausalLM(
  (base_model): BloomForCausalLM(
    (transformer): BloomModel(
      (word_embeddings): Embedding(46145, 2048)
      (word_embeddings_layernorm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
      (h): ModuleList(
        (0): BloomBlock(
          (input_layernorm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
          (self_attention): BloomAttention(
            (query_key_value): Linear(in_features=2048, out_features=6144, bias=True)
            (dense): Linear(in_features=2048, out_features=2048, bias=True)
            (attention_dropout): Dropout(p=0.0, inplace=False)
          )
          (post_attention_layernorm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
          (mlp): BloomMLP(
            (dense_h_to_4h): Linear(in_features=2048, out_features=8192, bias=True)
            (gelu_impl): BloomGelu()
            (dense_4h_to_h): Linear(in_features=8192, out_features=2048, bias=True)
          )
        )
        (1): BloomBlock(
          (input_layernorm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
          (self_attention): BloomAttention(
            (query_key_value): Linear(in_features=2048, out_features=6144, bias=True)
            (dense): Linear(in_features=2048, out_features=2048, bias=True)
            (attention_dropout): Dropout(p=0.0, inplace=False)
          )
          (post_attention_layernorm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
          (mlp): BloomMLP(
            (dense_h_to_4h): Linear(in_features=2048, out_features=8192, bias=True)
            (gelu_impl): BloomGelu()
            (dense_4h_to_h): Linear(in_features=8192, out_features=2048, bias=True)
          )
        )
        (2): BloomBlock(
          (input_layernorm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
          (self_attention): BloomAttention(
            (query_key_value): Linear(in_features=2048, out_features=6144, bias=True)
            (dense): Linear(in_features=2048, out_features=2048, bias=True)
            (attention_dropout): Dropout(p=0.0, inplace=False)
          )
          (post_attention_layernorm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
          (mlp): BloomMLP(
            (dense_h_to_4h): Linear(in_features=2048, out_features=8192, bias=True)
            (gelu_impl): BloomGelu()
            (dense_4h_to_h): Linear(in_features=8192, out_features=2048, bias=True)
          )
        )
        (3): BloomBlock(
          (input_layernorm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
          (self_attention): BloomAttention(
            (query_key_value): Linear(in_features=2048, out_features=6144, bias=True)
            (dense): Linear(in_features=2048, out_features=2048, bias=True)
            (attention_dropout): Dropout(p=0.0, inplace=False)
          )
          (post_attention_layernorm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
          (mlp): BloomMLP(
            (dense_h_to_4h): Linear(in_features=2048, out_features=8192, bias=True)
            (gelu_impl): BloomGelu()
            (dense_4h_to_h): Linear(in_features=8192, out_features=2048, bias=True)
          )
        )
        (4): BloomBlock(
          (input_layernorm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
          (self_attention): BloomAttention(
            (query_key_value): Linear(in_features=2048, out_features=6144, bias=True)
            (dense): Linear(in_features=2048, out_features=2048, bias=True)
            (attention_dropout): Dropout(p=0.0, inplace=False)
          )
          (post_attention_layernorm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
          (mlp): BloomMLP(
            (dense_h_to_4h): Linear(in_features=2048, out_features=8192, bias=True)
            (gelu_impl): BloomGelu()
            (dense_4h_to_h): Linear(in_features=8192, out_features=2048, bias=True)
          )
        )
        (5): BloomBlock(
          (input_layernorm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
          (self_attention): BloomAttention(
            (query_key_value): Linear(in_features=2048, out_features=6144, bias=True)
            (dense): Linear(in_features=2048, out_features=2048, bias=True)
            (attention_dropout): Dropout(p=0.0, inplace=False)
          )
          (post_attention_layernorm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
          (mlp): BloomMLP(
            (dense_h_to_4h): Linear(in_features=2048, out_features=8192, bias=True)
            (gelu_impl): BloomGelu()
            (dense_4h_to_h): Linear(in_features=8192, out_features=2048, bias=True)
          )
        )
        (6): BloomBlock(
          (input_layernorm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
          (self_attention): BloomAttention(
            (query_key_value): Linear(in_features=2048, out_features=6144, bias=True)
            (dense): Linear(in_features=2048, out_features=2048, bias=True)
            (attention_dropout): Dropout(p=0.0, inplace=False)
          )
          (post_attention_layernorm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
          (mlp): BloomMLP(
            (dense_h_to_4h): Linear(in_features=2048, out_features=8192, bias=True)
            (gelu_impl): BloomGelu()
            (dense_4h_to_h): Linear(in_features=8192, out_features=2048, bias=True)
          )
        )
        (7): BloomBlock(
          (input_layernorm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
          (self_attention): BloomAttention(
            (query_key_value): Linear(in_features=2048, out_features=6144, bias=True)
            (dense): Linear(in_features=2048, out_features=2048, bias=True)
            (attention_dropout): Dropout(p=0.0, inplace=False)
          )
          (post_attention_layernorm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
          (mlp): BloomMLP(
            (dense_h_to_4h): Linear(in_features=2048, out_features=8192, bias=True)
            (gelu_impl): BloomGelu()
            (dense_4h_to_h): Linear(in_features=8192, out_features=2048, bias=True)
          )
        )
        (8): BloomBlock(
          (input_layernorm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
          (self_attention): BloomAttention(
            (query_key_value): Linear(in_features=2048, out_features=6144, bias=True)
            (dense): Linear(in_features=2048, out_features=2048, bias=True)
            (attention_dropout): Dropout(p=0.0, inplace=False)
          )
          (post_attention_layernorm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
          (mlp): BloomMLP(
            (dense_h_to_4h): Linear(in_features=2048, out_features=8192, bias=True)
            (gelu_impl): BloomGelu()
            (dense_4h_to_h): Linear(in_features=8192, out_features=2048, bias=True)
          )
        )
        (9): BloomBlock(
          (input_layernorm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
          (self_attention): BloomAttention(
            (query_key_value): Linear(in_features=2048, out_features=6144, bias=True)
            (dense): Linear(in_features=2048, out_features=2048, bias=True)
            (attention_dropout): Dropout(p=0.0, inplace=False)
          )
          (post_attention_layernorm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
          (mlp): BloomMLP(
            (dense_h_to_4h): Linear(in_features=2048, out_features=8192, bias=True)
            (gelu_impl): BloomGelu()
            (dense_4h_to_h): Linear(in_features=8192, out_features=2048, bias=True)
          )
        )
        (10): BloomBlock(
          (input_layernorm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
          (self_attention): BloomAttention(
            (query_key_value): Linear(in_features=2048, out_features=6144, bias=True)
            (dense): Linear(in_features=2048, out_features=2048, bias=True)
            (attention_dropout): Dropout(p=0.0, inplace=False)
          )
          (post_attention_layernorm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
          (mlp): BloomMLP(
            (dense_h_to_4h): Linear(in_features=2048, out_features=8192, bias=True)
            (gelu_impl): BloomGelu()
            (dense_4h_to_h): Linear(in_features=8192, out_features=2048, bias=True)
          )
        )
        (11): BloomBlock(
          (input_layernorm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
          (self_attention): BloomAttention(
            (query_key_value): Linear(in_features=2048, out_features=6144, bias=True)
            (dense): Linear(in_features=2048, out_features=2048, bias=True)
            (attention_dropout): Dropout(p=0.0, inplace=False)
          )
          (post_attention_layernorm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
          (mlp): BloomMLP(
            (dense_h_to_4h): Linear(in_features=2048, out_features=8192, bias=True)
            (gelu_impl): BloomGelu()
            (dense_4h_to_h): Linear(in_features=8192, out_features=2048, bias=True)
          )
        )
        (12): BloomBlock(
          (input_layernorm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
          (self_attention): BloomAttention(
            (query_key_value): Linear(in_features=2048, out_features=6144, bias=True)
            (dense): Linear(in_features=2048, out_features=2048, bias=True)
            (attention_dropout): Dropout(p=0.0, inplace=False)
          )
          (post_attention_layernorm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
          (mlp): BloomMLP(
            (dense_h_to_4h): Linear(in_features=2048, out_features=8192, bias=True)
            (gelu_impl): BloomGelu()
            (dense_4h_to_h): Linear(in_features=8192, out_features=2048, bias=True)
          )
        )
        (13): BloomBlock(
          (input_layernorm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
          (self_attention): BloomAttention(
            (query_key_value): Linear(in_features=2048, out_features=6144, bias=True)
            (dense): Linear(in_features=2048, out_features=2048, bias=True)
            (attention_dropout): Dropout(p=0.0, inplace=False)
          )
          (post_attention_layernorm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
          (mlp): BloomMLP(
            (dense_h_to_4h): Linear(in_features=2048, out_features=8192, bias=True)
            (gelu_impl): BloomGelu()
            (dense_4h_to_h): Linear(in_features=8192, out_features=2048, bias=True)
          )
        )
        (14): BloomBlock(
          (input_layernorm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
          (self_attention): BloomAttention(
            (query_key_value): Linear(in_features=2048, out_features=6144, bias=True)
            (dense): Linear(in_features=2048, out_features=2048, bias=True)
            (attention_dropout): Dropout(p=0.0, inplace=False)
          )
          (post_attention_layernorm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
          (mlp): BloomMLP(
            (dense_h_to_4h): Linear(in_features=2048, out_features=8192, bias=True)
            (gelu_impl): BloomGelu()
            (dense_4h_to_h): Linear(in_features=8192, out_features=2048, bias=True)
          )
        )
        (15): BloomBlock(
          (input_layernorm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
          (self_attention): BloomAttention(
            (query_key_value): Linear(in_features=2048, out_features=6144, bias=True)
            (dense): Linear(in_features=2048, out_features=2048, bias=True)
            (attention_dropout): Dropout(p=0.0, inplace=False)
          )
          (post_attention_layernorm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
          (mlp): BloomMLP(
            (dense_h_to_4h): Linear(in_features=2048, out_features=8192, bias=True)
            (gelu_impl): BloomGelu()
            (dense_4h_to_h): Linear(in_features=8192, out_features=2048, bias=True)
          )
        )
        (16): BloomBlock(
          (input_layernorm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
          (self_attention): BloomAttention(
            (query_key_value): Linear(in_features=2048, out_features=6144, bias=True)
            (dense): Linear(in_features=2048, out_features=2048, bias=True)
            (attention_dropout): Dropout(p=0.0, inplace=False)
          )
          (post_attention_layernorm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
          (mlp): BloomMLP(
            (dense_h_to_4h): Linear(in_features=2048, out_features=8192, bias=True)
            (gelu_impl): BloomGelu()
            (dense_4h_to_h): Linear(in_features=8192, out_features=2048, bias=True)
          )
        )
        (17): BloomBlock(
          (input_layernorm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
          (self_attention): BloomAttention(
            (query_key_value): Linear(in_features=2048, out_features=6144, bias=True)
            (dense): Linear(in_features=2048, out_features=2048, bias=True)
            (attention_dropout): Dropout(p=0.0, inplace=False)
          )
          (post_attention_layernorm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
          (mlp): BloomMLP(
            (dense_h_to_4h): Linear(in_features=2048, out_features=8192, bias=True)
            (gelu_impl): BloomGelu()
            (dense_4h_to_h): Linear(in_features=8192, out_features=2048, bias=True)
          )
        )
        (18): BloomBlock(
          (input_layernorm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
          (self_attention): BloomAttention(
            (query_key_value): Linear(in_features=2048, out_features=6144, bias=True)
            (dense): Linear(in_features=2048, out_features=2048, bias=True)
            (attention_dropout): Dropout(p=0.0, inplace=False)
          )
          (post_attention_layernorm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
          (mlp): BloomMLP(
            (dense_h_to_4h): Linear(in_features=2048, out_features=8192, bias=True)
            (gelu_impl): BloomGelu()
            (dense_4h_to_h): Linear(in_features=8192, out_features=2048, bias=True)
          )
        )
        (19): BloomBlock(
          (input_layernorm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
          (self_attention): BloomAttention(
            (query_key_value): Linear(in_features=2048, out_features=6144, bias=True)
            (dense): Linear(in_features=2048, out_features=2048, bias=True)
            (attention_dropout): Dropout(p=0.0, inplace=False)
          )
          (post_attention_layernorm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
          (mlp): BloomMLP(
            (dense_h_to_4h): Linear(in_features=2048, out_features=8192, bias=True)
            (gelu_impl): BloomGelu()
            (dense_4h_to_h): Linear(in_features=8192, out_features=2048, bias=True)
          )
        )
        (20): BloomBlock(
          (input_layernorm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
          (self_attention): BloomAttention(
            (query_key_value): Linear(in_features=2048, out_features=6144, bias=True)
            (dense): Linear(in_features=2048, out_features=2048, bias=True)
            (attention_dropout): Dropout(p=0.0, inplace=False)
          )
          (post_attention_layernorm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
          (mlp): BloomMLP(
            (dense_h_to_4h): Linear(in_features=2048, out_features=8192, bias=True)
            (gelu_impl): BloomGelu()
            (dense_4h_to_h): Linear(in_features=8192, out_features=2048, bias=True)
          )
        )
        (21): BloomBlock(
          (input_layernorm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
          (self_attention): BloomAttention(
            (query_key_value): Linear(in_features=2048, out_features=6144, bias=True)
            (dense): Linear(in_features=2048, out_features=2048, bias=True)
            (attention_dropout): Dropout(p=0.0, inplace=False)
          )
          (post_attention_layernorm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
          (mlp): BloomMLP(
            (dense_h_to_4h): Linear(in_features=2048, out_features=8192, bias=True)
            (gelu_impl): BloomGelu()
            (dense_4h_to_h): Linear(in_features=8192, out_features=2048, bias=True)
          )
        )
        (22): BloomBlock(
          (input_layernorm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
          (self_attention): BloomAttention(
            (query_key_value): Linear(in_features=2048, out_features=6144, bias=True)
            (dense): Linear(in_features=2048, out_features=2048, bias=True)
            (attention_dropout): Dropout(p=0.0, inplace=False)
          )
          (post_attention_layernorm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
          (mlp): BloomMLP(
            (dense_h_to_4h): Linear(in_features=2048, out_features=8192, bias=True)
            (gelu_impl): BloomGelu()
            (dense_4h_to_h): Linear(in_features=8192, out_features=2048, bias=True)
          )
        )
        (23): BloomBlock(
          (input_layernorm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
          (self_attention): BloomAttention(
            (query_key_value): Linear(in_features=2048, out_features=6144, bias=True)
            (dense): Linear(in_features=2048, out_features=2048, bias=True)
            (attention_dropout): Dropout(p=0.0, inplace=False)
          )
          (post_attention_layernorm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
          (mlp): BloomMLP(
            (dense_h_to_4h): Linear(in_features=2048, out_features=8192, bias=True)
            (gelu_impl): BloomGelu()
            (dense_4h_to_h): Linear(in_features=8192, out_features=2048, bias=True)
          )
        )
      )
      (ln_f): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
    )
    (lm_head): Linear(in_features=2048, out_features=46145, bias=False)
  )
  (prompt_encoder): ModuleDict(
    (default): PromptEmbedding(
      (embedding): Embedding(10, 2048)
    )
  )
  (word_embeddings): Embedding(46145, 2048)
)

可以看到最后有一个prompt_encoder,这个就是我们Prompt-Tuning要训练的模块(只训练该Embedding模块)

查看训练参数数量:

model.print_trainable_parameters()

软提示: 

trainable params: 20,480 || all params: 1,303,132,160 || trainable%: 0.0015715980795071467

可以看到采用软提示方式虽然用于训练的参数量非常小,但是训练过程中loss降低地很缓慢,可能需要训练很多轮才有一个不错的效果。 

硬提示: 

trainable params: 16,384 || all params: 1,303,128,064 || trainable%: 0.0012572824154909767

可以看到采用硬提示方式即使用于训练的参数量非常小,但是训练过程中loss降低地也还可以,至少可以说在这个模型|任务上会比soft的效果要好。

3.5 配置训练参数

args = TrainingArguments(
    output_dir="./chatbot",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    logging_steps=10,
    num_train_epochs=1
)

3.6 创建训练器

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=tokenized_ds,
    data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True),
)

3.7 模型训练

trainer.train()

在深度学习中,checkpoint(检查点) 是模型训练过程中的一种存储机制,它用于保存训练的中间状态或最终结果,方便模型恢复训练、评估或者部署。

一个典型的 checkpoint 文件通常包含以下几个部分:模型权重(Model Weights)、优化器状态(Optimizer State)、训练进度(Training Progress / Epoch Number)、模型配置文件(Model Config)。 

peft_model = peft_model.cuda()
ipt = tokenizer("Human: {}\n{}".format("考试有哪些技巧?", "").strip() + "\n\nAssistant: ", return_tensors="pt").to(peft_model.device)
print(tokenizer.decode(peft_model.generate(**ipt, max_length=128, do_sample=True)[0], skip_special_tokens=True))

之前我们在配置训练参数时,默认每500步存一次checkpoint,那么我们如何加载呢?

from peft import PeftModel

peft_model = PeftModel.from_pretrained(model=model, model_id="./chatbot/checkpoint-500/")

这里要确保模型在同一台设备上(如果训练的时候是在GPU上,则需要重新加载),CPU上的可以按照上述方式直接加载。

3.8 模型推理

peft_model = peft_model.cuda()
ipt = tokenizer("Human: {}\n{}".format("考试有哪些技巧?", "").strip() + "\n\nAssistant: ", return_tensors="pt").to(peft_model.device)
print(tokenizer.decode(peft_model.generate(**ipt, max_length=128, do_sample=True)[0], skip_special_tokens=True))

在本次课程中,hard在该任务|模型上的表现优于soft,那么soft就真的一无是处么?如何对其进行优化呢?这个我们将在下个章节中展开讨论 

《基于检索增强生成(RAG)的多场景问答系统》 一、项目背景 针对目标公司业务,设计一个融合传统检索技术大模型生成能力的问答系统,重点解决以下问题: 简单问题:通过高效检索快速返回标准答案(如产品FAQ) 复杂问题:调用大模型生成解释性答案,并利用检索结果约束生成内容可信度 内容安全:通过相似度检测拦截重复/低质用户提问,降低服务器负载 二、技术架构 mermaid graph TD A[用户提问] --> B{问题类型判断} B -->|简单问题| C[检索式问答] B -->|复杂问题| D[生成式问答] C --> E[返回结构化答案] D --> E E --> F[答案去重缓存] F --> G[用户反馈] G --> H[模型迭代] subgraph 检索式问答 C1[(知识库)] --> C2[BM25/Jaccard检索] C2 --> C3[相似度排序] end subgraph 生成式问答 D1[微调大模型] --> D2[Prompt工程] D2 --> D3[检索增强生成] end 三、核心模块代码关联 原代码模块 迁移应用场景 升级策略 情感分析代码 用户意图分类 将情感标签替换为问题类型(简单/复杂) LSTM模型代码 微调轻量化大模型(如T5-small) 将LSTM替换为Transformer架构 相似度检测代码 答案去重缓存 Jaccard→余弦相似度+Sentence-BERT 四、关键技术实现 1. 混合问答路由(复用情感分析逻辑) def route_question(question): # 使用预训练模型判断问题复杂度 inputs = tokenizer(question, return_tensors="pt") outputs = classifier_model(**inputs) prob_complex = torch.softmax(outputs.logits, dim=1)[0][1] return "生成式" if prob_complex > 0.7 else "检索式" 2. 检索增强生成(融合代码2/3)from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration # 初始化RAG模型 tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq") retriever = RagRetriever.from_pretrained("facebook/rag-sequence-nq", index_name="exact") model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq") def answer_with_rag(question): inputs = tokenizer(question, return_tensors="pt") outputs = model.generate(input_ids=inputs["input_ids"]) return tokenizer.decode(outputs[0], skip_special_tokens=True) 3. 动态知识更新(复用相似度检测) python 复制 class KnowledgeManager: def __init__(self): self.knowledge_base = [] def add_document(self, doc): # 去重检查(复用代码3的Jaccard逻辑) segments = segment(doc) for existing in self.knowledge_base: if jaccard_similarity(segments, existing[&#39;segments&#39;]) > 0.8: return False self.knowledge_base.append({&#39;text&#39;: doc, &#39;segments&#39;: segments}) return True 五、面试展示策略 1. 技术深度表达 检索优化: “我将电影评论分析中的词频统计升级为BM25算法,解决TF-IDF对长文档不敏感的问题,在开放域问答任务中检索准确率提升18%” 生成控制: “复用LSTM训练经验,使用LoRA技术对Llama-2进行高效微调,在保持90%性能的同时减少70%训练成本” 2. 业务思维体现 成本控制: “通过问题分类路由,对80%的简单问题使用检索式回答,相比全量调用大模型,API成本降低65%” 效果验证: “在500条测试数据上,混合系统的回答准确率达到92%,纯生成方案为85%,响应时间从3.2s降至1.4s” 3. 难点突破 冷启动问题: “初期缺乏标注数据时,使用SimCSE无监督训练相似度模型,实现检索模块的零样本启动” 生成幻觉抑制: “通过检索结果约束生成空间,加入惩罚项使模型更关注知识库内容,虚构回答比例从23%降至7%” 六、延伸建议 部署演示:用Gradio搭建Web Demo,展示实时问答流程 性能对比:提供不同方案(纯检索/纯生成/混合)的准确率-耗时曲线图 优化路线:提出引入强化学习实现自动路由策略的下一步计划 帮我完善以上文档提供的代码
03-08
<think>嗯,用户让我帮忙完善基于RAG的多场景问答系统的代码。首先,我需要仔细看一下用户提供的文档和现有的代码结构。用户已经给出了一些核心模块的代码片段,比如混合问答路由、检索增强生成和动态知识更新,但可能这些代码还不够完整或者需要优化。 首先,混合问答路由部分,用户用了预训练模型来判断问题类型,返回生成式或检索式。但这里可能需要更详细的处理,比如处理CUDA设备、错误处理,以及添加日志记录,这样在实际部署时会更加可靠。另外,路由的阈值0.7是硬编码,或许可以改成可配置的参数,方便后续调整。 接下来是检索增强生成的部分,用户使用了Facebook的RAG模型,但代码示例中没有考虑输入长度限制,也没有处理可能出现的异常。此外,生成的参数比如max_length和num_beams可以调整,以平衡生成速度和质量。需要添加这些参数,并且确保模型在GPU上运行(如果有的话)。 动态知识更新的部分,用户提到了去重检查,但Jaccard相似度的具体实现没有给出。可能需要补充相似度计算的函数,以及文档分段的逻辑。另外,去重的阈值0.8也是硬编码,应该作为参数传入。同时,文档分段的方法需要具体实现,比如按句子或段落分割,这里可能需要调用现有的分词工具或库。 然后,用户可能希望代码更模块化,比如将知识管理器封装成类,并添加必要的辅助函数。比如,文档分段可以使用nltk的句子分割,但需要处理可能的安装和导入问题。另外,去重检查不仅要考虑Jaccard,还可以结合余弦相似度,使用Sentence-BERT等更先进的模型,但用户提到复用现有代码,所以可能先实现Jaccard,后续再升级。 还要考虑代码的健壮性,比如异常处理。例如,在路由问题类型时,模型可能抛出异常,需要捕获并处理。同样,在生成答案时,可能需要限制输入长度,避免模型处理过长的文本。 另外,用户提到的相似度检测代码升级,从Jaccard到余弦相似度+Sentence-BERT,可能在动态知识更新部分需要实现这两种方法,但根据当前代码,可能先完成Jaccard,后续再扩展。需要确保代码结构允许灵活替换相似度计算方法。 最后,代码的示例可能需要更完整,包括必要的导入语句和依赖库的说明,比如安装transformers、nltk等。同时,添加注释和文档字符串,方便其他人理解和维护。 总结下来,我需要完善各个模块的代码,补充缺失的部分,处理设备分配、错误处理、参数配置,以及模块之间的衔接。同时,确保代码符合最佳实践,比如模块化、可配置性、日志记录等,这样用户在实际部署时会更加顺利。</think>我将从代码健壮性、模块化和业务适配三个角度对现有代码进行完善,并添加关键的中文注释: 一、混合问答路由增强版(新增设备管理/配置注入) ```python import torch from transformers import AutoTokenizer, AutoModelForSequenceClassification class QuestionRouter: def __init__(self, model_name="bert-base-uncased", threshold=0.7, device="cuda"): self.device = device if torch.cuda.is_available() else "cpu" self.tokenizer = AutoTokenizer.from_pretrained(model_name) self.classifier = AutoModelForSequenceClassification.from_pretrained(model_name).to(self.device) self.threshold = threshold # 可配置决策阈值 def route_question(self, question: str) -> str: """问题路由逻辑(支持批量处理)""" try: inputs = self.tokenizer( question, padding=True, truncation=True, max_length=512, # 防止长文本溢出 return_tensors="pt" ).to(self.device) with torch.no_grad(): outputs = self.classifier(**inputs) probs = torch.softmax(outputs.logits, dim=1) return "生成式" if probs[0][1] > self.threshold else "检索式" except Exception as e: print(f"路由异常: {str(e)}") return "检索式" # 降级处理 ``` 二、增强版检索增强生成(优化生成控制) ```python from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration class RAGGenerator: def __init__(self, model_name="facebook/rag-sequence-nq"): self.tokenizer = RagTokenizer.from_pretrained(model_name) self.retriever = RagRetriever.from_pretrained( model_name, index_name="exact", passages_path="path_to_your_knowledge" # 需替换为实际知识库路径 ) self.model = RagSequenceForGeneration.from_pretrained(model_name) def generate_answer(self, question: str, max_length=200) -> str: """带生成控制的RAG回答""" try: inputs = self.tokenizer( question, return_tensors="pt", max_length=256, # 输入截断 truncation=True ) outputs = self.model.generate( input_ids=inputs["input_ids"], max_length=max_length, num_beams=4, # 平衡生成速度质量 no_repeat_ngram_size=3 # 减少重复 ) return self.tokenizer.decode( outputs[0], skip_special_tokens=True, clean_up_tokenization_spaces=True ) except Exception as e: print(f"生成异常: {str(e)}") return "暂时无法回答该问题" ``` 三、动态知识管理系统(含相似度检测) ```python from nltk import sent_tokenize import numpy as np class KnowledgeManager: def __init__(self, similarity_threshold=0.8): self.knowledge_base = [] self.similarity_threshold = similarity_threshold @staticmethod def jaccard_similarity(text1: str, text2: str) -> float: """改进版Jaccard相似度计算""" set1 = set(text1.lower().split()) set2 = set(text2.lower().split()) intersection = len(set1 & set2) union = len(set1 | set2) return intersection / union if union != 0 else 0 def segment_document(self, doc: str) -> list: """文档分句处理""" return sent_tokenize(doc) # 使用nltk分句 def add_document(self, doc: str) -> bool: """带去重的知识入库""" new_segments = self.segment_document(doc) for existing in self.knowledge_base: existing_segments = existing[&#39;segments&#39;] # 段落级相似度比对 sim_scores = [ self.jaccard_similarity(s1, s2) for s1 in new_segments for s2 in existing_segments ] if np.max(sim_scores) > self.similarity_threshold: return False self.knowledge_base.append({ &#39;text&#39;: doc, &#39;segments&#39;: new_segments, &#39;timestamp&#39;: datetime.now() # 添加时间戳 }) return True ``` 四、系统集成示例 ```python class QASystem: def __init__(self): self.router = QuestionRouter() self.retriever = BM25Retriever() # 需实现检索模块 self.generator = RAGGenerator() self.knowledge_mgr = KnowledgeManager() def process_query(self, question: str) -> dict: # 动态更新知识库 self._update_knowledge() # 路由决策 mode = self.router.route_question(question) if mode == "检索式": answer = self.retriever.search(question) else: answer = self.generator.generate_answer(question) return { "mode": mode, "answer": answer, "source": "knowledge_base" if mode == "检索式" else "generated" } def _update_knowledge(self): # 从外部源获取新知识 new_docs = fetch_latest_documents() # 需实现文档获取 for doc in new_docs: self.knowledge_mgr.add_document(doc) ``` 代码优化亮点: 1. 异常处理:所有关键操作添加try-catch块,保障系统稳定性 2. 配置注入:阈值参数通过构造函数注入,提高系统灵活性 3. 生成控制:添加num_beams/no_repeat_ngram_size等参数抑制低质量生成 4. 知识追溯:在返回结果中添加source字段,满足可解释性需求 5. 动态更新:定时检查新知识文档,保持系统时效性 6. 性能优化:添加CUDA设备管理,支持GPU加速 建议后续扩展: 1. 添加缓存层(Redis)存储高频问题答案 2. 实现混合检索策略(BM25 + 向量检索) 3. 添加生成结果的事实性校验模块 4. 集成prompt模板管理系统 5. 添加API调用速率限制和负载监控
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值