对 ChatGLM-6B 做 LoRA Fine-tuning
ChatGLM-6B 是一个支持中英双语的对话语言模型,基于 GLM (General Language Model)。它只有 62 亿个参数,量化后最低 (INT4 量化) 只需要 6GB 的显存,完全可以部署到消费级显卡上。在实际使用这个模型一段时间以后,我们发现模型的对话表现能力确实非常不错。那么,基于这个模型做 Fine-tuning 就非常有价值了。
声明:
本文提供的所有技术信息,都基于 THUDM/chatglm-6b 的历史版本:
096f3de6b4959ce38bef7bb05f3129c931a3084e
。
源码地址:
搭建依赖环境
安装 PyTorch 环境:
pip install torch torchvision torchaudio
按照 ChatGLM-6B 的官方指导,安装软件依赖环境:
pip install protobuf==3.20.0 transformers==4.26.1 icetk cpm_kernels
为了做 LoRA,还要安装 peft
pip install peft
加载模型和 Tokenizer
from transformers import AutoTokenizer, AutoModel
checkpoint = "THUDM/chatglm-6b"
revision = "096f3de6b4959ce38bef7bb05f3129c931a3084e"
model = AutoModel.from_pretrained(checkpoint, revision=revision, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(checkpoint, revision=revision, trust_remote_code=True)
分析模型结构
模型加载完后,我们可以打印这个 model
和 tokenizer
,建立对模型的基本认知。
首先打印model
:
print(model)
得到如下结果:
ChatGLMForConditionalGeneration(
(transformer): ChatGLMModel(
(word_embeddings): Embedding(150528, 4096)
(layers): ModuleList(
(0-27): 28 x GLMBlock(
(input_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
(attention): SelfAttention(
(rotary_emb): RotaryEmbedding()
(query_key_value): Linear(in_features=4096, out_features=12288, bias=True)
(dense): Linear(in_features=4096, out_features=4096, bias=True)
)
(post_attention_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
(mlp): GLU(
(dense_h_to_4h): Linear(in_features=4096, out_features=16384, bias=True)
(dense_4h_to_h): Linear(in_features=16384, out_features=4096, bias=True)
)
)
)
(final_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
)
(lm_head): Linear(in_features=4096, out_features=150528, bias=False)
)
简单分析这个模型结构,至少可以得到如下一些信息:
- 模型使用了 Transformer 结构,因此可以使用 LoRA 进行 Fine-tuning
- 从 Word Embedding 层可以看出,词汇表大小是
150528
- LoRA 可以操作的目标是:
query_key_value
再打印tokenizer
:
print(tokenizer)
得到如下结果(为了便于阅读,已对结果做了分行处理):
ChatGLMTokenizer(
name_or_path='THUDM/chatglm-6b',
vocab_size=150344,
model_max_length=2048,
is_fast=False,
padding_side='left',
truncation_side='right',
special_tokens={
'bos_token': '<sop>',
'eos_token': '</s>',
'unk_token': '<unk>',
'pad_token': '<pad>',
'mask_token': '[MASK]'
}
)
这里有几个可以关注的点:
- 词汇表大小
vocab_size
是150344
- 不是一个 fast Tokenizer(
is_fast
的值是False
) - 特殊 token 包括:
bos
eos
pad
和mask
为什么 model 中的词汇表大小是 150528
,而 tokenizer
中定义的词汇表大小却是 150344
呢?读者可以带着这个疑问去读一读模型项目的源码,看看能不能找到答案。
配置 LoRA
借助 peft 库,我们可以很方便地对模型注入 LoRA。
from peft import LoraConfig, get_peft_model, TaskType
def load_lora_config(model):
config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
inference_mode=False,
r=8,
lora_alpha=32,
lora_dropout=0.1,
target_modules=["query_key_value"]
)
return get_peft_model(model, config)
model = load_lora_config(model)
打印可训练的参数量:
model.print_trainable_parameters()
得到如下结果:
trainable params: 3670016 || all params: 6258876416 || trainable%: 0.05863697820615348
可以看到,总的参数量是 6,258,876,416
,可训练的参数量是 3,670,016
,占比 0.0586%
左右。训练参数量只是百万级别的,可谓相当友好了!另外需要注意的一点是,ChatGLM-6B 是一个因果语言模型 (Causal Language Model),因此我们这里选择的任务类型是 CAUSAL_LM
。
构建数据集
定义常量
构建之前,我们先定义几个特殊 Token 常量:
bos = tokenizer.bos_token_id
eop = tokenizer.eop_token_id
pad = tokenizer.pad_token_id
mask = tokenizer.mask_token_id
gmask = tokenizer.sp_tokenizer[tokenizer.gMASK_token]
将这几个值打印出来:
print("bos = ", bos)
print("eop = ", eop)
print("pad = ", pad)
print("mask = ", mask)
print("gmask = ", gmask)
得到如下结果:
bos = 150004
eop = 150005
pad = 20003
mask = 150000
gmask = 150001
我们也可以直接用这个常量结果替换动态计算的部分。常量修改后的结果变成:
bos = 150004
eop = 150005
pad = 20003
mask = 150000
gmask = 150001
除了上面定义的 Token 常量,我们还需要定义模型训练绑定的设备名,以及最大输入长度和最大输出长度等,如下:
device = "cuda"
max_src_length = 200
max_dst_length = 500
开发者可以结合自己的显卡性能和要处理的数据集特点来确定这些最大长度。
测试 Tokenizer 的编解码
我们可以先做个简单的测试:
text = "AI探险家"
print(tokenizer.encode(text, add_special_tokens = True))
print(tokenizer.encode(text, add_special_tokens = False))
输出结果是:
[26738, 98715, 83920, 150001, 150004]
[26738, 98715, 83920]
从这个结果可以看出,“AI探险家”这几个字的裸编码是 [26738, 98715, 83920]
。为什么是这样呢?我们可以对每一个数值再解码,看看输出结果:
print(tokenizer.decode([26738]))
print(tokenizer.decode([98715]))
print(tokenizer.decode([83920]))
输出结果是:
AI
探险
家
观察这个结果,读者应该能对词汇表建立基本的认知了。读者如果有兴趣,还可以分别针对 “A” “I” “探” “险” 这几个字分别编码,看看编码结果是什么。
另外,当 add_special_tokens = True
时,编码结果会在末尾添加 150001
和 150004
,也就是 gmask
和 bos
。请注意,我们的训练数据,要按照如下编码要求进行构造:
[token, ..., token, gmask, bos, token, ... token, eop]
因此,前半部分文本的编码可以直接让 add_special_tokens = True
,后半部分文本的编码则让 add_special_tokens = False
,最后再拼接一个 eop
。
定义 Prompt
我们 Fine-tuning 的任务是问答任务(简称 QA),因此一个简单的 Prompt 是这样的:
PROMPT_PATTERN = "问:{}\n答: "
{}
里填入 QA 训练集的问题文本。在显存有限的情况下,如果不对长文本做限制处理,很容易出现类似 CUDA out of memory
这样的报错。处理长文本,在给定编码后的数组上限时,可能存在这么几种方式:
- 截断末尾超出部分的编码
- 截断前面超出部分的编码
- 丢掉训练样本
每一种方式都有各自的优劣,开发者可以根据自身数据的特点自行选择一种处理方式。当然,如果你的显存够大,也可以不处理。本文以上述第一种方式进行处理。
为了不把 PROMPT_PATTERN
中的 \n答:
这几个字截断掉,我们将整个 PROMPT_PATTERN
拆成两部分:
PROMPT_PATTERN = "问:{}"
SEP_PATTERN = "\n答: "
基于这份 Prompt 模板,我们定义下面三个辅助方法:
def create_prompt(question):
return PROMPT_PATTERN.format(question), SEP_PATTERN
def create_prompt_ids(tokenizer, question, max_src_length):
prompt, sep = create_prompt(question)
sep_ids = tokenizer.encode(
sep,
add_special_tokens = True
)
sep_len = len(sep_ids)
special_tokens_num = 2
prompt_ids = tokenizer.encode(
prompt,
max_length = max_src_length - (sep_len - special_tokens_num),
truncation = True,
add_special_tokens = False
)
return prompt_ids + sep_ids
def create_inputs_and_labels(tokenizer, question, answer, device):
prompt = create_prompt_ids(tokenizer, question, max_src_length)
completion = tokenizer.encode(
answer,
max_length = max_dst_length,
truncation = True,
add_special_tokens = False
)
inputs = prompt + completion + [eop]
labels = [-100] * len(prompt) + completion + [eop]
inputs = torch.tensor(inputs, dtype=torch.long, device=device)
labels = torch.tensor(labels, dtype=torch.long, device=device)
return inputs, labels
值得注意的两点:
- 从
create_prompt_ids
这个函数实现可以看出,我们编码分隔符SEP_PATTERN
时自动添加了前面所述的 2 个特殊 Token。 - 对
create_inputs_and_labels
的函数实现中,我们将labels
无需处理的部分用数值-100
来表示。因为ChatGLMForConditionalGeneration
内部在计算损失函数的时候,用的是torch.nn.CrossEntropyLoss
。该函数的参数之一ignore_index
默认值是-100
。这就让我们在计算损失函数时,无需考虑非标识部分的数值。
构建 Attention Mask 和 Position IDs
def get_attention_mask(tokenizer, input_ids, device):
seq = input_ids.tolist()
context_len = seq.index(bos)
seq_len = len(seq)
attention_mask = torch.ones((seq_len, seq_len), device=device)
attention_mask.tril_()
attention_mask[..., :context_len] = 1
attention_mask.unsqueeze_(0)
attention_mask = (attention_mask < 0.5).bool()
return attention_mask
def get_position_ids(tokenizer, input_ids, device, position_encoding_2d=True):
seq = input_ids.tolist()
context_len = seq.index(bos)
seq_len = len(seq)
mask_token = mask if mask in seq else gmask
use_gmask = False if mask in seq else gmask
mask_position = seq.index(mask_token)
if position_encoding_2d:
position_ids = torch.arange(seq_len, dtype=torch.long, device=device)
if not use_gmask:
position_ids[context_len:] = mask_position
block_position_ids = torch.cat((
torch.zeros(context_len, dtype=torch.long, device=device),
torch.arange(seq_len - context_len, dtype=torch.long, device=device) + 1
))
position_ids = torch.stack((position_ids, block_position_ids), dim=0)
else:
position_ids = torch.arange(seq_len, dtype=torch.long, device=device)
if not use_gmask:
position_ids[context_len:] = mask_position
return position_ids
在这个通用实现中,我们针对 mask
和 gmask
两种情况做了区分,同时也对是否执行 position_encoding_2d
分情况处理。本文的 QA 任务采用的是 gmask
,并且使用 position_encoding_2d = True
。
我们可以构建下面的问答,来验证下这几个函数的输出:
test_data = {
"question": "AI探险家帅不帅?",
"answer": "非常帅!"
}
inputs, labels = create_inputs_and_labels(tokenizer, **test_data, device=device)
attention_mask = get_attention_mask(tokenizer, inputs, device=device)
position_ids = get_position_ids(tokenizer, inputs, device=device)
print("inputs: \n", inputs.tolist())
print("\nlabels: \n", labels.tolist())
print("\nposition_ids: \n", position_ids.tolist())
print("\nattention_mask: \n", attention_mask.tolist())
输出结果(为了便于阅读,已对输出进行格式化操作):
inputs:
[20005, 84286, 20012, 31943, 98715, 83920, 87359, 83848, 87359, 20031, 20005, 20004, 87342, 20012, 150001, 150004, 20005, 84122, 87359, 20035, 150005]
labels:
[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 20005, 84122, 87359, 20035, 150005]
position_ids:
[
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5]
]
attention_mask:
[[
[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True],
[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True],
[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True],
[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True],
[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True],
[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True],
[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True],
[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True],
[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True],
[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True],
[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True],
[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True],
[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True],
[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True],
[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True],
[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True],
[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True],
[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True],
[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True],
[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True],
[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False]]]
结合论文观察数据,基本符合预期。
创建数据集
我们先定义具有如下格式的训练数据:
train_data = [
{"question": "问题1", "answer": "答案1"},
{"question": "问题2", "answer": "答案2"},
]
定义好格式后,我们先创建一个 QADataset
类,如下:
from torch.utils.data import Dataset
class QADataset(Dataset):
def __init__(self, data, tokenizer) -> None:
super().__init__()
self.data = data
self.tokenizer = tokenizer
def __getitem__(self, index):
item_data = self.data[index]
tokenizer = self.tokenizer
input_ids, labels = create_inputs_and_labels(
tokenizer,
device=device,
**item_data
)
attention_mask = get_attention_mask(tokenizer, input_ids, device)
position_ids = get_position_ids(tokenizer, input_ids, device)
return {
"input_ids": input_ids,
"labels": labels,
"attention_mask": attention_mask,
"position_ids": position_ids
}
def __len__(self):
return len(self.data)
然后创建一个 Data Collator:
def collate_fn(batch):
input_ids = []
attention_mask = []
labels = []
position_ids = []
for obj in batch:
input_ids.append(obj['input_ids'])
labels.append(obj['labels'])
attention_mask.append(obj['attention_mask'])
position_ids.append(obj['position_ids'])
return {
'input_ids': torch.stack(input_ids),
'attention_mask': torch.stack(attention_mask),
'labels': torch.stack(labels),
'position_ids':torch.stack(position_ids)
}
开始训练
from transformers import TrainingArguments, Trainer
model.to(device)
training_args = TrainingArguments(
"output",
fp16 =True,
save_steps = 500,
save_total_limit = 3,
gradient_accumulation_steps=1,
per_device_train_batch_size = 1,
learning_rate = 1e-4,
max_steps=1500,
logging_steps=50,
remove_unused_columns=False,
seed=0,
data_seed=0,
group_by_length=False,
dataloader_pin_memory=False
)
class ModifiedTrainer(Trainer):
def compute_loss(self, model, inputs, return_outputs=False):
return model(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
position_ids=inputs["position_ids"],
labels=inputs["labels"],
).loss
train_dataset = QADataset(train_data, tokenizer=tokenizer)
trainer = ModifiedTrainer(
model=model,
train_dataset=train_dataset,
args=training_args,
data_collator=collate_fn,
tokenizer=tokenizer
)
trainer.train()
预测
response, history = model.chat(tokenizer, "AI探险家的颜值如何?", history=[])
print(response)
保存训练模型
import os
def save_tuned_parameters(model, path):
saved_params = {
k: v.to(device)
for k, v in model.named_parameters()
if v.requires_grad
}
torch.save(saved_params, path)
save_tuned_parameters(model, os.path.join("/path/to/output", "chatglm-6b-lora.pt"))
重载训练后的模型
checkpoint = "THUDM/chatglm-6b"
revision = "096f3de6b4959ce38bef7bb05f3129c931a3084e"
model = AutoModel.from_pretrained(checkpoint, revision=revision, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(checkpoint, revision=revision, trust_remote_code=True)
model = load_lora_config(model)
model.load_state_dict(torch.load(f"/path/to/output/chatglm-6b-lora.pt"), strict=False)
model.half().cuda().eval()
response, history = model.chat(tokenizer, "AI探险家的颜值如何?", history=[])
print(response)