准备工作
- 准备特定场景微调数据1000左右的量级作为训练集,300左右的数据作为测试集(最好是英文版),数据集的格式示例如下:
// 判别式分类任务
[
{
"instrcution": "# Goal\nYou are a senior python programmer, please check is there any risk of endless loop in the code snippet.\n#code\n```python\nint n = 0\nwhile n < 100:\ns = get_token();\n```",
"input": "",
"output": "{\"label\":1}"
},
//...
]
//文本生成式任务
[
{
"instrcution": "# Goal\nYou are a senior python programmer, please write a binary search algorithm.",
"input": "",
"output": "def binary_search(arr, target):\nleft = 0\nright = len(arr) - 1\nwhile left <= right:\nmid = left + (right - left) // 2\nif arr[mid] == target:\n. return mid\nelif arr[mid] > target:\nright = mid - 1\nelse:\nleft = mid + 1\nreturn -1"
},
//...
]
- 下载模型至本地:linux系统下载huggingface大模型教程;
- 可以基于开源的微调框架进行:
# unsloth是一个比较轻量级的框架,适合开发者从0-1自定义微调路径
# 本实战教程是基于unsloth框架的
https://github.com/unslothai/unsloth
# 里面有很多微调教程、量化的大模型、数据集可参考使用
# 其他主流微调框架,功能全面,可视化完善
# https://github.com/huggingface/trl
# https://github.com/OpenAccess-AI-Collective/axolotl
# https://github.com/hiyouga/LLaMA-Factory
# https://github.com/modelscope/swift
# ...还有很多其他框架,百花齐放
- python环境准备:
查看系统GPU配置信息:
import torch
# Print the CUDA compute-capability major version of the current GPU; the
# install commands below are chosen based on this number.
# get_device_capability() raises when no CUDA device/driver is present, so
# check availability first instead of failing with a traceback.
if torch.cuda.is_available():
    major_version, _ = torch.cuda.get_device_capability()
    print(major_version)
else:
    print("CUDA is not available: no GPU detected, or torch was built without CUDA support.")
根据配置信息安装指定的包
# When major_version >= 8 the GPU is typically Ampere/Hopper (RTX 30xx, RTX 40xx, A100, H100, L40, ...);
# set up the environment with the following command.
pip3 install --no-deps packaging ninja einops flash-attn xformers trl peft accelerate bitsandbytes
# When major_version < 8 the GPU is typically an older card (V100, Tesla T4, RTX 20xx, ...);
# use this command instead (same as above minus packaging, ninja, einops and flash-attn).
# NOTE: --no-deps skips dependency resolution, so anything these packages depend on
# (e.g. torch) must already be installed in the environment.
# NOTE(review): neither command installs unsloth itself, which this tutorial is based on —
# confirm it is installed separately.
pip3 install --no-deps xformers trl peft accelerate bitsandbytes
导包如果报错的话,大概率是包版本的冲突,如果重装一些包还是存在冲突,实在解决不了,就新建一个conda环境,按提示一个一个装
LoRA微调
- step1:样本转换,json转换至jsonl的脚本cover_alpaca2jsonl.py
import argparse
import json
from tqdm import tqdm
def format_example(example: dict) -> dict:
    """Convert one Alpaca-style record into a ``context``/``target`` pair.

    Args:
        example: A record with an ``"instruction"`` key, an ``"output"`` key
            and an optional ``"input"`` key (omitted from the prompt when it
            is missing or empty).

    Returns:
        A dict with ``"context"`` (the prompt text, ending in ``"Answer: "``)
        and ``"target"`` (the record's ``"output"`` value, unchanged).

    Raises:
        KeyError: If ``"instruction"`` or ``"output"`` is missing.
    """
    # Keep each replacement field on one line: a line break inside a
    # single-quoted f-string is a SyntaxError before Python 3.12.
    context = f"Instruction: {example['instruction']}\n"
    if example.get("input"):
        context += f"Input: {example['input']}\n"
    context += "Answer: "
    target = example["output"]
    return {"context": context, "target": target}
def main():
    """Convert a JSON array of Alpaca-style records into a JSONL file.

    Command-line arguments:
        --data_path: Input JSON file containing a list of records.
        --save_path: Output file; each record is passed through
            ``format_example`` and written as one JSON object per line.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_path", type=str, default="data/train_data/samples_train_llama.json")
    parser.add_argument("--save_path", type=str, default="data/train_data/samples_train_llama.jsonl")
    args = parser.parse_args()

    # Explicit UTF-8 on both ends: the platform default encoding (e.g. GBK on
    # Chinese Windows) can fail on or corrupt non-ASCII training data.
    with open(args.data_path, encoding="utf-8") as f:
        examples = json.load(f)

    with open(args.save_path, "w", encoding="utf-8") as f:
        for example in tqdm(examples, desc="formatting.."):
            f.write(json.dumps(format_example(example)) + "\n")


if __name__ == "__main__":
    main()
- step2:生成tokenize文件夹,为后续自定义损失函数做数据准备,脚本tokenize_dataset_rows.py
import argparse
import json
from tqdm import tqdm
import datasets
import transformers
import os
def preprocess(tokenizer, config, example, max_seq_length, version):
    """Tokenize one ``context``/``target`` record into model input ids.

    Args:
        tokenizer: A Hugging Face tokenizer. For ``version='v2'`` it must also
            provide ``build_prompt(query, history)`` (presumably a ChatGLM2-style
            remote-code tokenizer — confirm for other models).
        config: Model config; only ``config.eos_token_id`` is read (v1 only).
        example: A record with ``"context"`` and ``"target"`` strings.
        max_seq_length: Truncation limit applied to the prompt and to the
            target *separately*, so the combined length can exceed it.
        version: ``'v1'`` (raw context as prompt, EOS id from ``config``) or
            ``'v2'`` (prompt built by ``tokenizer.build_prompt``, EOS id from
            ``tokenizer``).

    Returns:
        ``{"input_ids": prompt_ids + target_ids + [eos_id], "seq_len": len(prompt_ids)}``.
        ``seq_len`` is the index where the target tokens start (presumably
        used by the custom loss to mask out the prompt — verify in the caller).

    Raises:
        ValueError: If ``version`` is neither ``'v1'`` nor ``'v2'``.
    """
    if version == "v1":
        prompt = example["context"]
        target = example["target"]
        prompt_ids = tokenizer.encode(prompt, max_length=max_seq_length, truncation=True)
        # No special tokens on the target: it is appended directly after the prompt.
        target_ids = tokenizer.encode(
            target,
            max_length=max_seq_length,
            truncation=True,
            add_special_tokens=False,
        )
        input_ids = prompt_ids + target_ids + [config.eos_token_id]
        return {"input_ids": input_ids, "seq_len": len(prompt_ids)}

    if version == "v2":
        query = example["context"]
        target = example["target"]
        history = None  # single-turn samples: no dialogue history
        prompt = tokenizer.build_prompt(query, history)
        a_ids = tokenizer.encode(text=prompt, add_special_tokens=True, truncation=True,
                                 max_length=max_seq_length)
        b_ids = tokenizer.encode(text=target, add_special_tokens=False, truncation=True,
                                 max_length=max_seq_length)
        input_ids = a_ids + b_ids + [tokenizer.eos_token_id]
        return {"input_ids": input_ids, "seq_len": len(a_ids)}

    # Previously fell through and returned None, which only failed later at
    # the caller with an unrelated error.
    raise ValueError(f"unsupported version {version!r}; expected 'v1' or 'v2'")
def read_jsonl(path, max_seq_length, base_model_path, version='v1', skip_overlength=False):
tokenizer = transformers.AutoTokenizer.from_pretrained(
base_model_path, trust_remote_code=True)
config = transformers.AutoConfig.from_pretrained(
base_model_path, trust_remote_code=True, device_map='auto')
with open(path, "r") as f:
for line in tqdm(f.readlines()):
example = json.loads(line)
# feature = preprocess(tokenizer, config, example, max_seq_length)
feature = preprocess(tokenizer, config, example, max_seq_length, version)
if skip_overlength and