任务二 ChatgGLM微调
文件模型下载在任务一中
cd ChatGLM-6B-main
这个文件夹是阿里云提供的可以直接对ChatGLM-6B进行微调的文件
其中ptuning文件夹中的train.sh就是训练的代码
PRE_SEQ_LEN=8 训练
LR=1e-2
CUDA_VISIBLE_DEVICES=0 python main.py \
--do_train \
--train_file AdvertiseGen_Simple/train.json \
--validation_file AdvertiseGen_Simple/dev.json \
--prompt_column content \
--response_column summary \
--overwrite_cache \
--model_name_or_path chatglm-6b \
--output_dir output/adgen-chatglm-6b-pt-$PRE_SEQ_LEN-$LR \
--overwrite_output_dir \
--max_source_length 64 \
--max_target_length 64 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 16 \
--predict_with_generate \
--logging_steps 10 \
--save_steps 6 \
--learning_rate $LR \
--pre_seq_len $PRE_SEQ_LEN \
--quantization_bit 4 \
--num_train_epochs 10
PRE_SEQ_LEN:预设序列长度,表示输入对话历史的最大长度。
LR:学习率,控制每次模型参数更新的步长。
--do_train:指定进行训练。
--train_file:指定训练数据集文件的路径。
--validation_file:指定验证数据集文件的路径。
--prompt_column:指定对话历史的列名。
--response_column:指定目标响应的列名。
--overwrite_cache:如果存在缓存文件,将其覆盖。
--model_name_or_path:指定要微调的模型名称或路径。
--output_dir:指定输出目录。
--overwrite_output_dir:如果输出目录已存在,将其覆盖。
--max_source_length:输入对话历史的最大长度限制。
--max_target_length:目标响应的最大长度限制。
--per_device_train_batch_size:训练时每个设备的批次大小。
--per_device_eval_batch_size:评估时每个设备的批次大小。
--gradient_accumulation_steps:梯度累积的步数。
--predict_with_generate:预测时使用生成模式。
--logging_steps:每隔多少步打印一次日志。
--save_steps:每隔多少步保存一次模型。
--learning_rate:学习率。
--pre_seq_len:预设序列长度。
--quantization_bit:量化位数。
--num_train_epochs:训练的总轮数。
进入/mnt/workspace/ChatGLM-6B-main/ptuning>目录下
开始训练
bash train.sh
训练五次的结果
训练之后开始模型推理
bash evaluate.sh
推理的结果
predict_bleu-4:BLEU-4(Bilingual Evaluation Understudy)是一种机器翻译质量评估指标,取值范围为0到1,越接近1表示机器翻译结果与参考答案的相似度越高。
predict_rouge-1:ROUGE-1是一种评估自动生成摘要的质量指标,通过比较生成的摘要和参考摘要中的unigram(单个词)重叠情况进行评估,结果表示重叠比例。
predict_rouge-2:ROUGE-2是ROUGE-1的扩展,比较生成的摘要和参考摘要中的bigram(两个连续词)重叠情况进行评估。
predict_rouge-L:ROUGE-L是一种基于最长公共子序列的评估指标,用于评估生成的摘要和参考摘要之间的相似度。
predict_runtime:表示模型运行预测的时间,以时:分:秒的格式进行显示。
predict_samples:表示预测样本的数量。
predict_samples_per_second:每秒钟处理的预测样本数量。
predict_steps_per_second:每秒钟处理的预测步骤数量。
然后回到ChatGLM-6B-main文件夹下运行web_demo.py
需要注意一下,因为ChatGLM-6B是从git获取的源代码,部署时使用,而ChatGLM-6B-main是从阿里云获得的微调使用的代码,其中都有web_demo.py,但是其内容是不同的
微调后:
与微调训练之前对比
可以看出来 仅仅是5个epoch的微调训练就可以让ChatGLM发生变化