I recently changed jobs, and my new work involves large language models, so I took some time to run the chatGLM2-6b demo and fine-tune the model with QLoRA/LoRA.
This is a quick write-up of the process that doubles as a simple tutorial; along the way I also hit the pitfall of the QLoRA loss turning into NaN and training becoming unstable.
This tutorial does not cover the theory behind LoRA; please look that up separately if you need it.
1. I downloaded the chatGLM2-6b model from huggingface to my server in advance, because the server cannot connect to huggingface directly.
I placed it under /data/tmp/chatGLM2_6b_pretrain, which contains the model weights and the configuration files; you can grab all of them straight from huggingface.
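If your machine can reach huggingface (mine cannot, hence the manual copy), here is a minimal download sketch using huggingface_hub; repo_id THUDM/chatglm2-6b is the official repo, the local path matches the one used throughout this post, and it assumes a huggingface_hub version recent enough to support local_dir:

from huggingface_hub import snapshot_download

# Fetch the model weights and config files into the directory used below.
snapshot_download(
    repo_id="THUDM/chatglm2-6b",
    local_dir="/data/tmp/chatGLM2_6b_pretrain",
)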
2. Print the model structure
from transformers import AutoModel

model_name = "/data/tmp/chatGLM2_6b_pretrain"
model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
print(model)
ChatGLMForConditionalGeneration(
(transformer): ChatGLMModel(
(embedding): Embedding(
(word_embeddings): Embedding(65024, 4096)
)
(rotary_pos_emb): RotaryEmbedding()
(encoder): GLMTransformer(
(layers): ModuleList(
(0-27): 28 x GLMBlock(
(input_layernorm): RMSNorm()
(self_attention): SelfAttention(
(query_key_value): Linear(in_features=4096, out_features=4608, bias=True)
(core_attention): CoreAttention(
(attention_dropout): Dropout(p=0.0, inplace=False)
)
(dense): Linear(in_features=4096, out_features=4096, bias=False)
)
(post_attention_layernorm): RMSNorm()
(mlp): MLP(
(dense_h_to_4h): Linear(in_features=4096, out_features=27392, bias=False)
(dense_4h_to_h): Linear(in_features=13696, out_features=4096, bias=False)
)
)
)
(final_layernorm): RMSNorm()
)
(output_layer): Linear(in_features=4096, out_features=65024, bias=False)
)
)
Note that the query_key_value matrix is not three square matrices concatenated together: it should be Wq 4096*4096, Wk 4096*256, and Wv 4096*256, because ChatGLM2 uses multi-query (grouped) attention, where many query heads share a small number of key/value heads.
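A quick sanity check of that 4608 figure; the numbers below are hard-coded from the chatglm2-6b config as I read it (hidden_size=4096, num_attention_heads=32, kv_channels=128, multi_query_group_num=2), so treat them as an assumption:

num_attention_heads = 32         # query heads
kv_channels = 128                # per-head dimension
multi_query_group_num = 2        # shared key/value groups

q_size = num_attention_heads * kv_channels      # 32 * 128 = 4096
kv_size = multi_query_group_num * kv_channels   # 2 * 128  = 256 each for K and V

print(q_size + 2 * kv_size)  # 4096 + 256 + 256 = 4608, matching out_features above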
3. Print the model structure after adding LoRA
from transformers import AutoTokenizer, AutoModel, AutoConfig
from peft import LoraConfig, get_peft_model, TaskType

model_name = "/data/tmp/chatGLM2_6b_pretrain"
model = AutoModel.from_pretrained(model_name, trust_remote_code=True)

config = LoraConfig(
    peft_type="LORA",
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,
    r=8,
    lora_alpha=16,
    lora_dropout=0.1,
    fan_in_fan_out=False,
    bias='lora_only',
    target_modules=["query_key_value"]
)

model = get_peft_model(model, config)
print(model)
PeftModelForCausalLM(
(base_model): LoraModel(
(model): ChatGLMForConditionalGeneration(
(transformer): ChatGLMModel(
(embedding): Embedding(
(word_embeddings): Embedding(65024, 4096)
)
(rotary_pos_emb): RotaryEmbedding()
(encoder): GLMTransformer(
(layers): ModuleList(
(0-27): 28 x GLMBlock(
(input_layernorm): RMSNorm()
(self_attention): SelfAttention(
(query_key_value): Linear(
in_features=4096, out_features=4608, bias=True
(lora_dropout): ModuleDict(
(default): Dropout(p=0.1, inplace=False)
)
(lora_A): ModuleDict(
(default): Linear(in_features=4096, out_features=8, bias=False)
)
(lora_B): ModuleDict(
(default): Linear(in_features=8, out_features=4608, bias=False)
)
(lora_embedding_A): ParameterDict()
(lora_embedding_B): ParameterDict()
)
(core_attention): CoreAttention(
(attention_dropout): Dropout(p=0.0, inplace=False)
)
(dense): Linear(in_features=4096, out_features=4096, bias=False)
)
(post_attention_layernorm): RMSNorm()
(mlp): MLP(
(dense_h_to_4h): Linear(in_features=4096, out_features=27392, bias=False)
(dense_4h_to_h): Linear(in_features=13696, out_features=4096, bias=False)
)
)
)
(final_layernorm): RMSNorm()
)
(output_layer): Linear(in_features=4096, out_features=65024, bias=False)
)
)
)
)
You can see that under the query_key_value matrix there are now two extra linear layers, lora_A (4096 -> 8) and lora_B (8 -> 4608), matching r=8; these are the layers that will actually be trained.
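To double-check that only the LoRA layers (plus, because of bias='lora_only', the bias of the targeted query_key_value) will receive gradients, you can count the trainable parameters; a small sketch using plain PyTorch alongside peft's built-in helper:

# Everything except the LoRA matrices (and the target module's bias,
# due to bias='lora_only') stays frozen after get_peft_model.
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"trainable: {trainable} / {total} ({100 * trainable / total:.4f}%)")
model.print_trainable_parameters()  # peft's one-line version of the same report

The printed ratio should be well under 1% of the full 6B parameters.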
4. Prepare the dataset. We use the firefly dataset, which you can download from huggingface in jsonl format; you need to split it into a training set and a test set ahead of time, as in the sketch below.
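A minimal split sketch, assuming the downloaded file is called firefly-train-1.1M.jsonl (the actual file name on huggingface may differ) and that a random 98/2 line-level split is good enough here:

import random

src = "/data/tmp/firefly-train-1.1M.jsonl"  # hypothetical path; adjust to your download
random.seed(42)

with open(src, encoding="utf-8") as f:
    lines = [l for l in f if l.strip()]

random.shuffle(lines)
split = int(len(lines) * 0.98)

with open("/data/tmp/firefly_train.jsonl", "w", encoding="utf-8") as f:
    f.writelines(lines[:split])
with open("/data/tmp/firefly_test.jsonl", "w", encoding="utf-8") as f:
    f.writelines(lines[split:])

With the splits ready, the dataset wrapper qa_dataset.py turns each jsonl line into a question/answer pair: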
# -*- coding: utf-8 -*-
from torch.utils.data import Dataset
import torch
import json
import numpy as np


class QADataset(Dataset):
    def __init__(self, data_path, tokenizer, max_source_length, max_target_length) -> None:
        super().__init__()
        self.tokenizer = tokenizer
        self.max_source_length = max_source_length
        self.max_target_length = max_target_length
        self.max_seq_length = self.max_source_length + self.max_target_length

        self.data = []
        with open(data_path, "r", encoding='utf-8') as f:
            for line in f:
                # skip blank lines (lines read from the file keep their trailing '\n')
                if not line.strip():
                    continue
                json_line = json.loads(line)
                # sample jsonl line:
                # {'kind': 'NLI', 'input': '自然语言推理:\n前提:家里人心甘情愿地养他,还有几家想让他做女婿的\n假设:他是被家里人收养的孤儿', 'target': '中立'}
                kind = json_line["kind"]
                input = json_line["input"]
                target = json_line["target"]
                self.data.append({
                    "question": input,
                    "answer": "--**"+kind+"**--\n"+target
                })
        print(