Environment Setup
On a CentOS 7.9 machine with Miniconda (installation omitted), create an environment with Python 3.11.0:

```bash
conda create -n lorafinetune python==3.11.0  # create the environment
```
The following is the contents of requirements.txt:
```
# This file may be used to create an environment using:
# $ conda create --name <env> --file <this file>
# platform: linux-64
_libgcc_mutex=0.1=main
_openmp_mutex=5.1=1_gnu
accelerate=0.19.0=pypi_0
bzip2=1.0.8=h7b6447c_0
ca-certificates=2023.01.10=h06a4308_0
certifi=2023.5.7=pypi_0
charset-normalizer=3.1.0=pypi_0
cmake=3.26.3=pypi_0
cpm-kernels=1.0.11=pypi_0
filelock=3.12.0=pypi_0
fsspec=2023.5.0=pypi_0
huggingface-hub=0.14.1=pypi_0
icetk=0.0.4=pypi_0
idna=3.4=pypi_0
jinja2=3.1.2=pypi_0
ld_impl_linux-64=2.38=h1181459_1
libffi=3.4.4=h6a678d5_0
libgcc-ng=11.2.0=h1234567_1
libgomp=11.2.0=h1234567_1
libstdcxx-ng=11.2.0=h1234567_1
libuuid=1.41.5=h5eee18b_0
lit=16.0.3=pypi_0
markupsafe=2.1.2=pypi_0
mpmath=1.3.0=pypi_0
ncurses=6.4=h6a678d5_0
networkx=3.1=pypi_0
numpy=1.24.3=pypi_0
nvidia-cublas-cu11=11.10.3.66=pypi_0
nvidia-cuda-cupti-cu11=11.7.101=pypi_0
nvidia-cuda-nvrtc-cu11=11.7.99=pypi_0
nvidia-cuda-runtime-cu11=11.7.99=pypi_0
nvidia-cudnn-cu11=8.5.0.96=pypi_0
nvidia-cufft-cu11=10.9.0.58=pypi_0
nvidia-curand-cu11=10.2.10.91=pypi_0
nvidia-cusolver-cu11=11.4.0.1=pypi_0
nvidia-cusparse-cu11=11.7.4.91=pypi_0
nvidia-nccl-cu11=2.14.3=pypi_0
nvidia-nvtx-cu11=11.7.91=pypi_0
openssl=1.1.1t=h7f8727e_0
packaging=23.1=pypi_0
peft=0.3.0=pypi_0
pillow=9.5.0=pypi_0
pip=23.0.1=pypi_0
protobuf=3.20.0=pypi_0
psutil=5.9.5=pypi_0
python=3.11.0=h7a1cb2a_3
pyyaml=6.0=pypi_0
readline=8.2=h5eee18b_0
regex=2023.5.5=pypi_0
requests=2.30.0=pypi_0
sentencepiece=0.1.99=pypi_0
setuptools=66.0.0=pypi_0
sqlite=3.41.2=h5eee18b_0
sympy=1.12=pypi_0
tk=8.6.12=h1ccaba5_0
tokenizers=0.13.3=pypi_0
torch=2.0.1=pypi_0
torchaudio=2.0.2=pypi_0
torchvision=0.15.2=pypi_0
tqdm=4.65.0=pypi_0
transformers=4.27.1=pypi_0  # started with 4.26.1, which errored; mind the version number
triton=2.0.0=pypi_0
typing-extensions=4.5.0=pypi_0
tzdata=2023c=h04d1e81_0
urllib3=2.0.2=pypi_0
wheel=0.38.4=pypi_0
xz=5.4.2=h5eee18b_0
zlib=1.2.13=h5eee18b_0
```

Command to install the packages with pip:

```bash
pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
```
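Before moving on, a quick sanity check (a minimal sketch, not part of the original write-up) confirms the pinned versions and that the GPU is visible:

```python
# Environment sanity check: confirm the pinned versions and CUDA visibility.
import torch
import transformers
import peft

print(torch.__version__)          # expect 2.0.1
print(transformers.__version__)   # expect 4.27.1 (4.26.1 reportedly errored)
print(peft.__version__)           # expect 0.3.0
print(torch.cuda.is_available())  # should print True on the training machine
```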
Loading the Model and Tokenizer
```python
from transformers import AutoTokenizer, AutoModel

checkpoint = "/home/yangxin/workspace/hom/chatglm-6b"  # points at my downloaded chatglm-6b model
revision = "096f3de6b4959ce38bef7bb05f3129c931a3084e"  # run `git log` in the model repo and take the most recent commit

model = AutoModel.from_pretrained(checkpoint, revision=revision, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(checkpoint, revision=revision, trust_remote_code=True)
# The above loads the model and tokenizer

print(model)      # print the model
print(tokenizer)  # print the tokenizer

bos = tokenizer.bos_token_id
# eop = tokenizer.
eop = tokenizer.eos_token_id
pad = tokenizer.pad_token_id
mask = tokenizer.sp_tokenizer[tokenizer.mask_token]
gmask = tokenizer.sp_tokenizer[tokenizer.gmask_token]

print("bos = ", bos)
# print("eop = ", eop)
print("pad = ", pad)
print("mask = ", mask)
print("gmask = ", gmask)
# Fetch the special-token IDs and print them
```
Model Structure
```
/home/yangxin/miniconda3/envs/lorafinetune/bin/python /home/yangxin/workspace/hom/chatglm-finetune/my/jiazaitokenizer.py
Loading checkpoint shards: 100%|██████████| 8/8 [00:02<00:00, 2.71it/s]
ChatGLMForConditionalGeneration(
  (transformer): ChatGLMModel(
    (word_embeddings): Embedding(130528, 4096)
    (layers): ModuleList(
      (0-27): 28 x GLMBlock(
        (input_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
        (attention): SelfAttention(
          (rotary_emb): RotaryEmbedding()
          (query_key_value): Linear(in_features=4096, out_features=12288, bias=True)
          (dense): Linear(in_features=4096, out_features=4096, bias=True)
        )
        (post_attention_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
        (mlp): GLU(
          (dense_h_to_4h): Linear(in_features=4096, out_features=16384, bias=True)
          (dense_4h_to_h): Linear(in_features=16384, out_features=4096, bias=True)
        )
      )
    )
    (final_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=4096, out_features=130528, bias=False)
)
```
A quick look at this model structure yields at least the following information:

- The model uses the Transformer architecture, so it can be fine-tuned with LoRA.
- The word embedding layer shows that the vocabulary size is 130528.
- The target module for LoRA to operate on is `query_key_value` (see the sketch below for enumerating candidates).
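If you would rather enumerate candidate target modules programmatically than read them off the printed structure, a small sketch like this (not part of the original script; it assumes the `model` loaded above) lists every `nn.Linear` leaf name:

```python
import torch.nn as nn

# Collect the leaf names of every Linear layer; each of these is, in
# principle, a possible entry for LoRA's target_modules.
linear_names = {
    name.split(".")[-1]
    for name, module in model.named_modules()
    if isinstance(module, nn.Linear)
}
print(linear_names)
# For this model: {'query_key_value', 'dense', 'dense_h_to_4h', 'dense_4h_to_h', 'lm_head'}
```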
The same script also prints the tokenizer:

```
/home/yangxin/miniconda3/envs/lorafinetune/bin/python /home/yangxin/workspace/hom/chatglm-finetune/my/jiazaitokenizer.py
Loading checkpoint shards: 100%|██████████| 8/8 [00:02<00:00, 2.76it/s]
ChatGLMTokenizer(name_or_path='/home/yangxin/workspace/hom/chatglm-6b', vocab_size=130344, model_max_length=1000000000000000019884624838656, is_fast=False, padding_side='left', truncation_side='right', special_tokens={'bos_token': '<sop>', 'eos_token': '<eop>', 'unk_token': '<unk>', 'pad_token': '<pad>', 'mask_token': '[MASK]'})
```
A few points worth noting here:

- The vocabulary size (`vocab_size`) is 130344.
- It is not a fast tokenizer (`is_fast` is False).
- The special tokens include `bos`, `eos`, `pad`, and `mask`.
Besides printing the model and tokenizer, the jiazaitokenizer.py program (listed in full above) also prints the special-token IDs that the fine-tuning program will need.
Configuring LoRA
The following program wraps the model with a LoRA adapter and prints the trainable-parameter counts:
```python
from peft import LoraConfig, get_peft_model, TaskType
from transformers import AutoTokenizer, AutoModel

checkpoint = "/home/yangxin/workspace/hom/chatglm-6b"
revision = "096f3de6b4959ce38bef7bb05f3129c931a3084e"
model = AutoModel.from_pretrained(checkpoint, revision=revision, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(checkpoint, revision=revision, trust_remote_code=True)

def load_lora_config(model):
    config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        inference_mode=False,
        r=8,
        lora_alpha=32,
        lora_dropout=0.1,
        target_modules=["query_key_value"]
    )
    return get_peft_model(model, config)

model = load_lora_config(model)
model.print_trainable_parameters()
```
The output:
```
/home/yangxin/miniconda3/envs/lorafinetune/bin/python /home/yangxin/workspace/hom/chatglm-finetune/my/lora.py
Loading checkpoint shards: 100%|██████████| 8/8 [00:02<00:00, 2.73it/s]
trainable params: 3670016 || all params: 6176956416 || trainable%: 0.05941463324063059
```
As you can see, the total parameter count is 6,176,956,416, of which 3,670,016 are trainable, roughly 0.0594%. With the trainable parameters only in the millions, this is remarkably friendly to modest hardware. One more thing to note: ChatGLM-6B is a causal language model, which is why the task type chosen here is `CAUSAL_LM`.
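As a sanity check, the trainable count is exactly what LoRA with `r=8` on `query_key_value` across 28 layers should add: each targeted layer gains an A matrix of shape r×4096 and a B matrix of shape 12288×r.

```python
# LoRA adds A (r x d_in) and B (d_out x r) per targeted module.
r, d_in, d_out, n_layers = 8, 4096, 12288, 28
per_layer = r * d_in + d_out * r   # 32768 + 98304 = 131072
print(per_layer * n_layers)        # 3670016, matching print_trainable_parameters()
```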
Building the Dataset
Defining Constants
The values printed by the jiazaitokenizer.py program above are used to define the token constants for fine-tuning:
```python
bos = 130004
eop = 130005
pad = 3
mask = 130000
gmask = 130001
```
Besides the token constants defined above, we also need to define the device the model is bound to for training, as well as the maximum input and output lengths:

```python
device = "cuda"
max_src_length = 200
max_dst_length = 500
```

Developers can choose these maximum lengths based on their GPU's capacity and the characteristics of the dataset to be processed.
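One way to pick these bounds for your own data (a sketch, not from the original article; it assumes `train_data` in the question/answer format introduced below) is to inspect the token-length distribution:

```python
# Inspect token lengths to pick max_src_length / max_dst_length with headroom.
q_lens = [len(tokenizer.encode(item["question"], add_special_tokens=False))
          for item in train_data]
a_lens = [len(tokenizer.encode(item["answer"], add_special_tokens=False))
          for item in train_data]
print("longest question:", max(q_lens), "longest answer:", max(a_lens))
```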
Defining the Prompt
Our fine-tuning task is question answering (QA for short), so a simple prompt looks like this:
```python
PROMPT_PATTERN = "问:{}\n答: "
```
The `{}` is filled with the question text from the QA training set. When GPU memory is limited, leaving long texts uncapped easily triggers errors like `CUDA out of memory`. Given an upper bound on the encoded sequence length, there are several ways to handle long texts:
- Truncate the encoded tokens that overflow at the end
- Truncate the encoded tokens that overflow at the beginning
- Drop the training sample
Each approach has its own tradeoffs, and developers can pick one based on the characteristics of their own data. Of course, if your GPU has enough memory, you can also leave long texts alone. This article uses the first approach. To avoid truncating the `\n答: ` characters inside `PROMPT_PATTERN`, we split the whole `PROMPT_PATTERN` into two parts:

```python
PROMPT_PATTERN = "问:{}"
SEP_PATTERN = "\n答: "
```

Based on this prompt template, we define the following three helper functions:
```python
import torch  # needed for torch.tensor below

def create_prompt(question):
    return PROMPT_PATTERN.format(question), SEP_PATTERN

def create_prompt_ids(tokenizer, question, max_src_length):
    prompt, sep = create_prompt(question)
    sep_ids = tokenizer.encode(
        sep,
        add_special_tokens=True
    )
    sep_len = len(sep_ids)
    special_tokens_num = 2
    prompt_ids = tokenizer.encode(
        prompt,
        max_length=max_src_length - (sep_len - special_tokens_num),
        truncation=True,
        add_special_tokens=False
    )
    return prompt_ids + sep_ids

def create_inputs_and_labels(tokenizer, question, answer, device):
    prompt = create_prompt_ids(tokenizer, question, max_src_length)
    completion = tokenizer.encode(
        answer,
        max_length=max_dst_length,
        truncation=True,
        add_special_tokens=False
    )
    inputs = prompt + completion + [eop]
    labels = [-100] * len(prompt) + completion + [eop]
    inputs = torch.tensor(inputs, dtype=torch.long, device=device)
    labels = torch.tensor(labels, dtype=torch.long, device=device)
    return inputs, labels
```
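A quick way to check these helpers (a sketch with a made-up QA pair; it assumes the constants and tokenizer loaded earlier and a CUDA device) is to inspect what they return:

```python
# inputs and labels have the same length: prompt + completion + [eop];
# the prompt portion of labels is all -100 and is excluded from the loss.
inputs, labels = create_inputs_and_labels(
    tokenizer, question="你好", answer="你好!", device="cuda"
)
print(inputs.shape, labels.shape)
print((labels == -100).sum().item(), "prompt positions masked out of the loss")
```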
Two points worth noting:

- As the `create_prompt_ids` implementation shows, encoding the separator `SEP_PATTERN` automatically adds the 2 special tokens mentioned earlier.
- In `create_inputs_and_labels`, the portions of `labels` that should not be trained on are marked with the value `-100`, because `ChatGLMForConditionalGeneration` computes its loss with `torch.nn.CrossEntropyLoss`, whose `ignore_index` parameter defaults to `-100`. This lets us ignore the non-target positions entirely when the loss is computed; see the sketch below.
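The second point is easy to verify in isolation. A minimal sketch, independent of ChatGLM, showing that `-100` targets contribute nothing to the loss:

```python
import torch
import torch.nn as nn

loss_fn = nn.CrossEntropyLoss()            # ignore_index defaults to -100
logits = torch.randn(4, 10)                # 4 positions over a 10-token vocab
labels = torch.tensor([-100, -100, 3, 7])  # the first two positions are "prompt"
# The loss is averaged over the two non-ignored positions only.
print(loss_fn(logits, labels))
```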
Building the Attention Mask and Position IDs
Omitted in the original. The actual implementations appear as `get_attention_mask` and `get_position_ids` in the full finetune.py listing below. In brief: ChatGLM attends bidirectionally over the entire context up to the `bos` token and causally after it, and uses 2D position IDs (absolute positions plus block positions) when `position_encoding_2d` is enabled.
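To make the idea concrete, here is a toy illustration with made-up token IDs, mirroring the `get_attention_mask` logic from the listing below:

```python
import torch

# Toy sequence: [token, token, gMASK, bos, answer, answer].
# Everything before bos is context and is fully visible; from bos
# onward attention is causal. True marks positions that are masked out.
seq = [5, 6, 130001, 130004, 7, 8]
context_len = seq.index(130004)            # position of bos
seq_len = len(seq)
m = torch.ones(seq_len, seq_len).tril_()   # lower-triangular causal mask
m[..., :context_len] = 1                   # context columns fully visible
print((m < 0.5).bool())
```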
Creating the Dataset
The format is:
```python
train_data = [
    {"question": "问题1", "answer": "答案1"},
    {"question": "问题2", "answer": "答案2"},
]
```
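Since the training script below reads the dataset from `data/dbstd.json`, a one-off helper like this sketch (not part of the original scripts) can write the list out in that format:

```python
import json
import os

os.makedirs("data", exist_ok=True)
with open("data/dbstd.json", "w", encoding="utf-8") as f:
    json.dump(train_data, f, ensure_ascii=False, indent=2)
```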
Training on the Dataset
finetune.py, the program that trains on the dataset (note the `ModifiedTrainer` subclass, which overrides `compute_loss` so that the custom `attention_mask` and `position_ids` built by the dataset are passed through to the model):
```python
import torch
import os
import json
from transformers import AutoTokenizer, AutoModel
from torch.cuda.amp import autocast
from utils import load_lora_config
import time

time_start = time.time()  # start timing

checkpoint = "/home/yangxin/workspace/hom/chatglm-6b"
# revision = "096f3de6b4959ce38bef7bb05f3129c931a3084e"
revision = "942945df047dee66f653c68ae0e56655045f1741"
model = AutoModel.from_pretrained(checkpoint, revision=revision, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(checkpoint, revision=revision, trust_remote_code=True)
model = load_lora_config(model)

# Token constants printed by jiazaitokenizer.py
bos = 130004
eop = 130005
pad = 3
mask = 130000
gmask = 130001
device = "cuda"
max_src_length = 200
max_dst_length = 500

PROMPT_PATTERN = "问:{}"
SEP_PATTERN = "\n答: "

def create_prompt(question):
    return PROMPT_PATTERN.format(question), SEP_PATTERN

def create_prompt_ids(tokenizer, question, max_src_length):
    prompt, sep = create_prompt(question)
    sep_ids = tokenizer.encode(
        sep,
        add_special_tokens=True
    )
    sep_len = len(sep_ids)
    special_tokens_num = 2
    prompt_ids = tokenizer.encode(
        prompt,
        max_length=max_src_length - (sep_len - special_tokens_num),
        truncation=True,
        add_special_tokens=False
    )
    return prompt_ids + sep_ids

def create_inputs_and_labels(tokenizer, question, answer, device):
    prompt = create_prompt_ids(tokenizer, question, max_src_length)
    completion = tokenizer.encode(
        answer,
        max_length=max_dst_length,
        truncation=True,
        add_special_tokens=False
    )
    inputs = prompt + completion + [eop]
    labels = [-100] * len(prompt) + completion + [eop]
    inputs = torch.tensor(inputs, dtype=torch.long, device=device)
    labels = torch.tensor(labels, dtype=torch.long, device=device)
    return inputs, labels

def get_attention_mask(tokenizer, input_ids, device):
    seq = input_ids.tolist()
    context_len = seq.index(bos)
    seq_len = len(seq)
    attention_mask = torch.ones((seq_len, seq_len), device=device)
    attention_mask.tril_()
    attention_mask[..., :context_len] = 1  # context is fully visible
    attention_mask.unsqueeze_(0)
    attention_mask = (attention_mask < 0.5).bool()
    return attention_mask

def get_position_ids(tokenizer, input_ids, device, position_encoding_2d=True):
    seq = input_ids.tolist()
    context_len = seq.index(bos)
    seq_len = len(seq)
    mask_token = mask if mask in seq else gmask
    use_gmask = False if mask in seq else gmask  # truthy flag: the gMASK id stands in for True
    mask_position = seq.index(mask_token)
    if position_encoding_2d:
        position_ids = torch.arange(seq_len, dtype=torch.long, device=device)
        if not use_gmask:
            position_ids[context_len:] = mask_position
        block_position_ids = torch.cat((
            torch.zeros(context_len, dtype=torch.long, device=device),
            torch.arange(seq_len - context_len, dtype=torch.long, device=device) + 1
        ))
        position_ids = torch.stack((position_ids, block_position_ids), dim=0)
    else:
        position_ids = torch.arange(seq_len, dtype=torch.long, device=device)
        if not use_gmask:
            position_ids[context_len:] = mask_position
    return position_ids

from torch.utils.data import Dataset

class QADataset(Dataset):
    def __init__(self, data, tokenizer) -> None:
        super().__init__()
        self.data = data
        self.tokenizer = tokenizer

    def __getitem__(self, index):
        item_data = self.data[index]
        tokenizer = self.tokenizer
        input_ids, labels = create_inputs_and_labels(
            tokenizer,
            device=device,
            **item_data
        )
        attention_mask = get_attention_mask(tokenizer, input_ids, device)
        position_ids = get_position_ids(tokenizer, input_ids, device)
        return {
            "input_ids": input_ids,
            "labels": labels,
            "attention_mask": attention_mask,
            "position_ids": position_ids
        }

    def __len__(self):
        return len(self.data)

def collate_fn(batch):
    input_ids = []
    attention_mask = []
    labels = []
    position_ids = []
    for obj in batch:
        input_ids.append(obj['input_ids'])
        labels.append(obj['labels'])
        attention_mask.append(obj['attention_mask'])
        position_ids.append(obj['position_ids'])
    return {
        'input_ids': torch.stack(input_ids),
        'attention_mask': torch.stack(attention_mask),
        'labels': torch.stack(labels),
        'position_ids': torch.stack(position_ids)
    }

train_data = []
with open("data/dbstd.json", encoding="utf-8") as file:
    train_data = json.load(file)

from transformers import TrainingArguments, Trainer

model.to(device)

training_args = TrainingArguments(
    "output",
    fp16=True,
    save_steps=500,
    save_total_limit=3,
    gradient_accumulation_steps=1,
    per_device_train_batch_size=1,
    learning_rate=1e-4,
    max_steps=1500,
    logging_steps=50,
    remove_unused_columns=False,
    seed=0,
    data_seed=0,
    group_by_length=False,
    dataloader_pin_memory=False
)

class ModifiedTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        return model(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            position_ids=inputs["position_ids"],
            labels=inputs["labels"],
        ).loss

train_dataset = QADataset(train_data, tokenizer=tokenizer)
trainer = ModifiedTrainer(
    model=model,
    train_dataset=train_dataset,
    args=training_args,
    data_collator=collate_fn,
    tokenizer=tokenizer
)

trainer.train()

def save_tuned_parameters(model, path):
    saved_params = {
        k: v.to(device)
        for k, v in model.named_parameters()
        if v.requires_grad
    }
    torch.save(saved_params, path)

save_tuned_parameters(model, os.path.join("dbstd", "chatglm-6b-lora.pt"))  # save the LoRA weights

time_end = time.time()      # stop timing
time_c = time_end - time_start  # elapsed time
print('time cost', time_c, 's')

# Usage: python finetune.py
```
The output:
```
Loading checkpoint shards: 100%|██████████| 8/8 [00:02<00:00, 2.76it/s]
trainable params: 3670016 || all params: 6176956416 || trainable%: 0.05941463324063059
/home/yangxin/miniconda3/envs/lorafinetune/lib/python3.11/site-packages/transformers/optimization.py:391: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning
  warnings.warn(
  3%|▎         | 50/1500 [00:05<02:07, 11.36it/s]{'loss': 1.788, 'learning_rate': 9.693333333333335e-05, 'epoch': 0.07}
  7%|▋         | 100/1500 [00:09<02:04, 11.22it/s]{'loss': 0.6055, 'learning_rate': 9.360000000000001e-05, 'epoch': 0.13}
 10%|█         | 150/1500 [00:14<02:01, 11.14it/s]{'loss': 0.4876, 'learning_rate': 9.026666666666666e-05, 'epoch': 0.2}
 13%|█▎        | 200/1500 [00:18<01:54, 11.36it/s]{'loss': 0.4849, 'learning_rate': 8.693333333333334e-05, 'epoch': 0.27}
 17%|█▋        | 250/1500 [00:23<01:52, 11.16it/s]{'loss': 0.3971, 'learning_rate': 8.36e-05, 'epoch': 0.33}
 20%|██        | 300/1500 [00:27<01:49, 10.96it/s]{'loss': 0.4122, 'learning_rate': 8.026666666666666e-05, 'epoch': 0.4}
 23%|██▎       | 350/1500 [00:32<01:42, 11.20it/s]{'loss': 0.3522, 'learning_rate': 7.693333333333334e-05, 'epoch': 0.46}
 27%|██▋       | 400/1500 [00:36<01:36, 11.43it/s]{'loss': 0.4039, 'learning_rate': 7.36e-05, 'epoch': 0.53}
 30%|███       | 450/1500 [00:40<01:33, 11.26it/s]{'loss': 0.38, 'learning_rate': 7.026666666666668e-05, 'epoch': 0.6}
 33%|███▎      | 500/1500 [00:45<01:27, 11.39it/s]{'loss': 0.4503, 'learning_rate': 6.693333333333334e-05, 'epoch': 0.66}
 37%|███▋      | 550/1500 [00:57<01:25, 11.07it/s]{'loss': 0.5087, 'learning_rate': 6.36e-05, 'epoch': 0.73}
 40%|████      | 600/1500 [01:02<01:21, 11.09it/s]{'loss': 0.3129, 'learning_rate': 6.026666666666667e-05, 'epoch': 0.8}
 43%|████▎     | 650/1500 [01:06<01:11, 11.96it/s]{'loss': 0.4128, 'learning_rate': 5.693333333333334e-05, 'epoch': 0.86}
 47%|████▋     | 700/1500 [01:10<01:06, 12.00it/s]{'loss': 0.3357, 'learning_rate': 5.360000000000001e-05, 'epoch': 0.93}
 50%|█████     | 750/1500 [01:15<01:02, 11.93it/s]{'loss': 0.3897, 'learning_rate': 5.026666666666667e-05, 'epoch': 1.0}
 53%|█████▎    | 800/1500 [01:19<01:01, 11.38it/s]{'loss': 0.3037, 'learning_rate': 4.6933333333333333e-05, 'epoch': 1.06}
 57%|█████▋    | 850/1500 [01:23<00:57, 11.39it/s]{'loss': 0.3512, 'learning_rate': 4.36e-05, 'epoch': 1.13}
 60%|██████    | 900/1500 [01:28<00:50, 11.92it/s]{'loss': 0.375, 'learning_rate': 4.026666666666667e-05, 'epoch': 1.2}
 63%|██████▎   | 950/1500 [01:32<00:47, 11.60it/s]{'loss': 0.2468, 'learning_rate': 3.6933333333333334e-05, 'epoch': 1.26}
 67%|██████▋   | 1000/1500 [01:36<00:42, 11.65it/s]{'loss': 0.3056, 'learning_rate': 3.3600000000000004e-05, 'epoch': 1.33}
 70%|███████   | 1050/1500 [01:48<00:38, 11.75it/s]{'loss': 0.3082, 'learning_rate': 3.0266666666666666e-05, 'epoch': 1.39}
 73%|███████▎  | 1100/1500 [01:53<00:33, 11.78it/s]{'loss': 0.363, 'learning_rate': 2.6933333333333332e-05, 'epoch': 1.46}
 77%|███████▋  | 1150/1500 [01:57<00:29, 11.76it/s]{'loss': 0.2832, 'learning_rate': 2.36e-05, 'epoch': 1.53}
 80%|████████  | 1200/1500 [02:01<00:25, 11.75it/s]{'loss': 0.2635, 'learning_rate': 2.0266666666666667e-05, 'epoch': 1.59}
 83%|████████▎ | 1250/1500 [02:05<00:20, 11.99it/s]{'loss': 0.3078, 'learning_rate': 1.6933333333333333e-05, 'epoch': 1.66}
 87%|████████▋ | 1300/1500 [02:10<00:16, 11.96it/s]{'loss': 0.285, 'learning_rate': 1.3600000000000002e-05, 'epoch': 1.73}
 90%|█████████ | 1350/1500 [02:14<00:12, 11.86it/s]{'loss': 0.2494, 'learning_rate': 1.0266666666666668e-05, 'epoch': 1.79}
 93%|█████████▎| 1400/1500 [02:18<00:08, 11.92it/s]{'loss': 0.2882, 'learning_rate': 6.933333333333334e-06, 'epoch': 1.86}
 97%|█████████▋| 1450/1500 [02:22<00:04, 11.92it/s]{'loss': 0.3269, 'learning_rate': 3.6e-06, 'epoch': 1.93}
100%|██████████| 1500/1500 [02:26<00:00, 11.93it/s]{'loss': 0.248, 'learning_rate': 2.6666666666666667e-07, 'epoch': 1.99}
100%|██████████| 1500/1500 [02:35<00:00,  9.66it/s]
{'train_runtime': 155.2665, 'train_samples_per_second': 9.661, 'train_steps_per_second': 9.661, 'train_loss': 0.4075628541310628, 'epoch': 1.99}
time cost 170.35878324508667 s

Process finished with exit code 0
```
Prediction
chat.py, an interactive question-answering program:
```python
import torch
from transformers import AutoTokenizer, AutoModel
from utils import load_lora_config
from torch.cuda.amp import autocast

checkpoint = "/home/yangxin/workspace/hom/chatglm-6b"
# revision = "096f3de6b4959ce38bef7bb05f3129c931a3084e"
revision = "942945df047dee66f653c68ae0e56655045f1741"
model = AutoModel.from_pretrained(checkpoint, revision=revision, trust_remote_code=True)  # load the pretrained model
tokenizer = AutoTokenizer.from_pretrained(checkpoint, revision=revision, trust_remote_code=True)
model = load_lora_config(model)
model.load_state_dict(torch.load("dbstd/chatglm-6b-lora.pt"), strict=False)  # reload the fine-tuned LoRA weights
model.half().cuda().eval()

history = []
while True:
    print("[User]: ")
    msg = input()
    try:
        if msg.strip().upper() == "CLEAR":
            history = []
            print("Ok.")
            continue
        elif msg.strip().upper() == "EXIT":
            history = []
            print("Good Bye")
            break
        else:
            response, history = model.chat(tokenizer, msg, history=history)
            print("[中科软编程助手]: ")
            print(response)
    except Exception as e:
        print(str(e))

# Usage: python chat.py
```
References: