Follow on Bilibili for more hands-on tutorial videos: hallo128's personal space
[LLM] Code Notes on Andrew Ng's "Finetuning Large Language Models" (03_Instruction_tuning_lab_student)
Instruction Tuning (Code Walkthrough: Code and Output)
1. Defining the Inference Function
```python
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load the tokenizer of the pretrained model EleutherAI/pythia-70m with AutoTokenizer.
# Load the same checkpoint as a causal language model (AutoModelForCausalLM), used to generate text.
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-70m")
model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-70m")

# inference: run the model on a single input text.
# Arguments: the input text, the model, the tokenizer, and two optional limits:
#   max_input_tokens caps the number of prompt tokens (longer prompts are truncated);
#   max_output_tokens is passed to generate() as max_length, which caps the TOTAL
#   sequence length (prompt + generated tokens), not only the newly generated tokens.
def inference(text, model, tokenizer, max_input_tokens=1000, max_output_tokens=100):
    # Tokenize: convert the text into token IDs as a PyTorch tensor (return_tensors="pt"),
    # truncating (truncation=True) so the prompt has at most max_input_tokens tokens.
    input_ids = tokenizer.encode(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=max_input_tokens,
    )

    # Generate: move the input to the model's device (model.device) and generate a continuation.
    # max_length counts the prompt as well; max_new_tokens would cap only the new tokens.
    device = model.device
    generated_tokens_with_prompt = model.generate(
        input_ids=input_ids.to(device),
        max_length=max_output_tokens,
    )

    # Decode: turn the generated token IDs back into strings, skipping special tokens
    # (such as the end-of-text token). The decoded text still begins with the prompt.
    generated_text_with_prompt = tokenizer.batch_decode(
        generated_tokens_with_prompt, skip_special_tokens=True
    )

    # Strip the prompt: drop the first len(text) characters so only the answer remains.
    # This assumes the decoded prompt reproduces `text` exactly (e.g. no truncation).
    generated_text_answer = generated_text_with_prompt[0][len(text):]

    # Return the generated answer without the prompt.
    return generated_text_answer
```
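The prompt-stripping step cuts `len(text)` characters off the decoded output, which only works if the decoded output begins with exactly the original `text`. If the prompt was truncated to `max_input_tokens`, the decoded prompt is shorter than `text` and the slice eats into the answer. The pure-Python sketch below illustrates this with a hypothetical whitespace tokenizer (`toy_encode` and `toy_decode` are made up for illustration and are not transformers APIs) and contrasts it with stripping at the token level before decoding.

```python
from typing import List


def toy_encode(text: str, max_tokens: int) -> List[str]:
    # Hypothetical stand-in for tokenizer.encode(..., truncation=True, max_length=max_tokens):
    # split on whitespace and keep at most max_tokens "tokens".
    return text.split()[:max_tokens]


def toy_decode(tokens: List[str]) -> str:
    # Hypothetical stand-in for decoding: join the tokens with single spaces.
    return " ".join(tokens)


def strip_by_chars(prompt: str, output_tokens: List[str]) -> str:
    # Same idea as inference(): decode everything, then drop len(prompt) characters.
    return toy_decode(output_tokens)[len(prompt):]


def strip_by_tokens(prompt_tokens: List[str], output_tokens: List[str]) -> str:
    # Alternative: drop the prompt's tokens first, then decode only the new ones.
    return toy_decode(output_tokens[len(prompt_tokens):])


prompt = "one two three four five"
prompt_tokens = toy_encode(prompt, max_tokens=3)   # truncated to ['one', 'two', 'three']
output_tokens = prompt_tokens + ["six", "seven"]   # pretend the model appended two tokens

print(repr(strip_by_chars(prompt, output_tokens)))          # '' (the answer is lost)
print(repr(strip_by_tokens(prompt_tokens, output_tokens)))  # 'six seven'
```

With the real model, the token-level variant would decode `generated_tokens_with_prompt[:, input_ids.shape[1]:]` instead of the full sequence; for a decoder-only model such as Pythia, `generate` returns the prompt tokens followed by the new ones, so that slice keeps only the continuation.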
2. Loading the Dataset
```python
from datasets import load_dataset

# Load the fine-tuning dataset: load_dataset fetches the dataset "lamini/lamini_docs"
# into finetuning_dataset, which is then printed.
finetuning_dataset_path = "lamini/lamini_docs"
finetuning_dataset = load_dataset(finetuning_dataset_path)
print(finetuning_dataset)
```
```
DatasetDict({
    train: Dataset({
        features: ['question', 'answer', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 1260
    })
    test: Dataset({
        features: ['question', 'answer', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 140
    })
})
```
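The two splits hold 1260 and 140 rows, 1400 in total, so the test split is 140 / 1400 = 10% of the data. The notes do not say how this split was produced. Purely as a hypothetical illustration, a seeded shuffle-and-split over 1400 row indices (`split_indices` is a made-up helper) yields splits of the same sizes.

```python
import random
from typing import List, Tuple


def split_indices(n: int, test_fraction: float, seed: int = 0) -> Tuple[List[int], List[int]]:
    # Hypothetical helper: shuffle row indices with a fixed seed, then cut off the test split.
    indices = list(range(n))
    random.Random(seed).shuffle(indices)
    n_test = round(n * test_fraction)
    return indices[n_test:], indices[:n_test]


train_idx, test_idx = split_indices(1260 + 140, test_fraction=0.1)
print(len(train_idx), len(test_idx))  # 1260 140
```

In the `datasets` library, `Dataset.train_test_split(test_size=0.1, seed=...)` performs this kind of split and returns a `DatasetDict` with `train` and `test` keys.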
```python
# Take the first sample (test_sample) from the test split and print it.
test_sample = finetuning_dataset["test"][0]
print(test_sample)
```
```
{'question': 'Can Lamini generate technical documentation or user manuals for software projects?', 'answer': 'Yes, Lamini can generate technical documentation and user manuals for software projects. It uses natural language generation techniques to create clear and concise documentation that is easy to understand for both technical and non-technical users. This can save developers a significant amount of time and effort in creating documentation, allowing them to focus on other aspects of their projects.', 'input_ids': [5804, 418, 4988, 74, 6635, 7681, 10097, 390, 2608, 11595, 84, 323, 3694, 6493, 32, 4374, 13, 418, 4988, 74, 476, 6635, 7681, 10097, 285, 2608, 11595, 84, 323, 3694, 6493, 15, 733, 4648, 3626, 3448, 5978, 5609, 281, 2794, 2590, 285, 44003, 10097, 326, 310, 3477, 281, 2096, 323, 1097, 7681, 285, 1327, 14, 48746, 4212, 15, 831, 476, 5321, 12259, 247, 1534, 2408, 273, 673, 285, 3434, 275, 6153, 10097, 13, 6941, 731, 281, 2770, 327, 643, 7794, 273, 616, 6493, 15], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'labels': [5804, 418, 4988, 74, 6635, 7681, 10097, 390, 2608, 11595, 84, 323, 3694, 6493, 32, 4374, 13, 418, 4988, 74, 476, 6635, 7681, 10097, 285, 2608, 11595, 84, 323, 3694, 6493, 15, 733, 4648, 3626, 3448, 5978, 5609, 281, 2794, 2590, 285, 44003, 10097, 326, 310, 3477, 281, 2096, 323, 1097, 7681, 285, 1327, 14, 48746, 4212, 15, 831, 476, 5321, 12259, 247, 1534, 2408, 273, 673, 285, 3434, 275, 6153, 10097, 13, 6941, 731, 281, 2770, 327, 643, 7794, 273, 616, 6493, 15]}
```
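In this sample, `labels` is an exact copy of `input_ids`, `attention_mask` is all ones (no padding), and `input_ids` appear to encode the question followed directly by the answer. That is the usual setup for causal-LM training: Hugging Face causal-LM models shift the labels internally when computing the loss, so the prediction made at position i is scored against token i+1. The pure-Python sketch below uses the first four IDs of the sample to show the (input, target) pairs that shift produces; `build_example` and `next_token_pairs` are hypothetical helpers, not library functions.

```python
from typing import Dict, List, Tuple


def build_example(token_ids: List[int]) -> Dict[str, List[int]]:
    # Hypothetical helper mirroring the sample's fields: no padding, so every
    # position is attended to, and labels are a plain copy of the inputs.
    return {
        "input_ids": list(token_ids),
        "attention_mask": [1] * len(token_ids),
        "labels": list(token_ids),
    }


def next_token_pairs(example: Dict[str, List[int]]) -> List[Tuple[int, int]]:
    # The shift applied inside the loss: the prediction at position i
    # (having seen tokens 0..i) is compared with the label at position i + 1.
    return list(zip(example["input_ids"][:-1], example["labels"][1:]))


example = build_example([5804, 418, 4988, 74])  # first four IDs of the test sample
print(next_token_pairs(example))  # [(5804, 418), (418, 4988), (4988, 74)]
```

Because the model does this shift itself, the dataset can store `labels` unshifted, which is why the two lists in the sample are identical.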
3. Running Inference
- Question: 'Can Lamini generate technical documentation or user manuals for software projects?'
- Reference answer from the dataset: 'Yes, Lamini can generate technical documentation and user manuals for software projects. It uses natural language generation techniques to create clear and concise documentation that is easy to understand for both technical and non-technical users. This can save developers a significant amount of time and effort in creating documentation, allowing them to focus on other aspects of their projects.'
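Here the bare question is passed to `inference` as the prompt. Instruction-tuning data is often rendered into a fixed question/answer template so that the model learns where an answer begins. The sketch below shows the idea; the template string and both helpers are hypothetical and are not claimed to be what was used to build `lamini/lamini_docs`.

```python
# Hypothetical prompt template for question/answer pairs.
QA_TEMPLATE = "### Question:\n{question}\n\n### Answer:\n"


def format_prompt(question: str) -> str:
    # Wrap a bare question in the template; the model's continuation is read as the answer.
    return QA_TEMPLATE.format(question=question)


def format_training_text(question: str, answer: str) -> str:
    # For training, the answer follows the same prompt, so inference-time prompts match training text.
    return format_prompt(question) + answer


print(format_prompt("Can Lamini generate technical documentation or user manuals for software projects?"))
```

A base model that was never trained on such a template gains little from it; the template matters once the same format is used consistently in fine-tuning and inference.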
```python
# Run inference on the test sample's question (test_sample["question"]) and print the generated answer.
print(inference(test_sample["question"], model, tokenizer))
```
```
I have a question about the following:
How do I get the correct documentation to work?
A:
I think you need to use the following code:
A:
You can use the following code to get the correct documentation.
A:
You can use the following code to get the correct documentation.
A:
You can use the following
```
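This is the untuned base model's output: it does not answer the question, it repeats itself, and it stops mid-sentence ("You can use the following") because `max_length=100` caps prompt and generated tokens together. The toy loop below sketches both effects under stated assumptions: `toy_generate` is a hypothetical stand-in for `generate(max_length=...)`, and `cyclic_next` is a made-up deterministic "model" that always continues the same cycle, loosely mimicking how greedy decoding can lock into a loop.

```python
from typing import Callable, List, Optional


def toy_generate(
    prompt: List[str],
    next_token: Callable[[List[str]], str],
    max_length: int,
    eos: Optional[str] = None,
) -> List[str]:
    # Like generate(max_length=...): the limit counts prompt tokens plus new tokens.
    tokens = list(prompt)
    while len(tokens) < max_length:
        tok = next_token(tokens)
        tokens.append(tok)
        if tok == eos:  # stop early if an end-of-sequence token is produced
            break
    return tokens


# Hypothetical deterministic "model": it keeps cycling through the same 4 tokens.
CYCLE = ["A:", "use", "the", "following"]


def cyclic_next(tokens: List[str]) -> str:
    return CYCLE[len(tokens) % len(CYCLE)]


prompt = ["Q:", "How", "do", "I"]
out = toy_generate(prompt, cyclic_next, max_length=10)
print(out[len(prompt):])  # ['A:', 'use', 'the', 'following', 'A:', 'use']
```

Only 10 - 4 = 6 new tokens fit, so the output repeats the cycle and is cut off wherever the budget runs out, just as the real output ends mid-sentence.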