阿里云服务器ChatgGLM微调

任务二 ChatgGLM微调

文件模型下载在任务一中

cd ChatGLM-6B-main

这个文件夹是阿里云提供的可以直接对ChatGLM-6B进行微调的文件

其中ptuning文件夹中的train.sh就是训练的代码

PRE_SEQ_LEN=8  训练
LR=1e-2

CUDA_VISIBLE_DEVICES=0 python main.py \
    --do_train \
    --train_file AdvertiseGen_Simple/train.json \
    --validation_file AdvertiseGen_Simple/dev.json \
    --prompt_column content \
    --response_column summary \
    --overwrite_cache \
    --model_name_or_path chatglm-6b \
    --output_dir output/adgen-chatglm-6b-pt-$PRE_SEQ_LEN-$LR \
    --overwrite_output_dir \
    --max_source_length 64 \
    --max_target_length 64 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 16 \
    --predict_with_generate \
    --logging_steps 10 \
    --save_steps 6 \
    --learning_rate $LR \
    --pre_seq_len $PRE_SEQ_LEN \
    --quantization_bit 4 \
    --num_train_epochs 10


PRE_SEQ_LEN:预设序列长度,表示输入对话历史的最大长度。
LR:学习率,控制每次模型参数更新的步长。

--do_train:指定进行训练。
--train_file:指定训练数据集文件的路径。
--validation_file:指定验证数据集文件的路径。
--prompt_column:指定对话历史的列名。
--response_column:指定目标响应的列名。
--overwrite_cache:如果存在缓存文件,将其覆盖。
--model_name_or_path:指定要微调的模型名称或路径。
--output_dir:指定输出目录。
--overwrite_output_dir:如果输出目录已存在,将其覆盖。
--max_source_length:输入对话历史的最大长度限制。
--max_target_length:目标响应的最大长度限制。
--per_device_train_batch_size:训练时每个设备的批次大小。
--per_device_eval_batch_size:评估时每个设备的批次大小。
--gradient_accumulation_steps:梯度累积的步数。
--predict_with_generate:预测时使用生成模式。
--logging_steps:每隔多少步打印一次日志。
--save_steps:每隔多少步保存一次模型。
--learning_rate:学习率。
--pre_seq_len:预设序列长度。
--quantization_bit:量化位数。
--num_train_epochs:训练的总轮数。

进入/mnt/workspace/ChatGLM-6B-main/ptuning>目录下

开始训练

bash train.sh

在这里插入图片描述

训练五次的结果

训练之后开始模型推理

bash evaluate.sh

推理的结果

在这里插入图片描述

predict_bleu-4:BLEU-4(Bilingual Evaluation Understudy)是一种机器翻译质量评估指标,取值范围为0到1,越接近1表示机器翻译结果与参考答案的相似度越高。
predict_rouge-1:ROUGE-1是一种评估自动生成摘要的质量指标,通过比较生成的摘要和参考摘要中的unigram(单个词)重叠情况进行评估,结果表示重叠比例。
predict_rouge-2:ROUGE-2是ROUGE-1的扩展,比较生成的摘要和参考摘要中的bigram(两个连续词)重叠情况进行评估。
predict_rouge-L:ROUGE-L是一种基于最长公共子序列的评估指标,用于评估生成的摘要和参考摘要之间的相似度。
predict_runtime:表示模型运行预测的时间,以时:分:秒的格式进行显示。
predict_samples:表示预测样本的数量。
predict_samples_per_second:每秒钟处理的预测样本数量。
predict_steps_per_second:每秒钟处理的预测步骤数量。

然后回到ChatGLM-6B-main文件夹下运行web_demo.py

需要注意一下,因为ChatGLM-6B是从git获取的源代码,部署时使用,而ChatGLM-6B-main是从阿里云获得的微调使用的代码,其中都有web_demo.py,但是其内容是不同的

微调后:

在这里插入图片描述
在这里插入图片描述

与微调训练之前对比
在这里插入图片描述

在这里插入图片描述
在这里插入图片描述

可以看出来 仅仅是5个epoch的微调训练就可以让ChatGLM发生变化

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值