ChatGLM-6B微调推理实战


ChatGLM-6B是一个由清华大学和智谱AI联合研发的开源对话语言模型,它基于General Language Model(GLM)架构,具有62亿参数,并支持中英双语问答。结合模型量化技术,用户可以在消费级的显卡上进行本地部署。在INT4量化级别下,最低只需6GB显存即可运行。

运行环境:

⭐️环境安装:

使用 pip 安装依赖:`pip install -r requirements.txt`,其中 `transformers` 库版本推荐为 `4.27.1`,但理论上不低于 `4.23.1` 即可。 

⭐️代码调用方式:

>>> from transformers import AutoTokenizer, AutoModel
>>> tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
>>> model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
>>> model = model.eval()
>>> response, history = model.chat(tokenizer, "你好", history=[])
>>> print(response)
你好👋!我是人工智能助手 ChatGLM-6B,很高兴见到你,欢迎问我任何问题。
>>> response, history = model.chat(tokenizer, "晚上睡不着应该怎么办", history=history)
>>> print(response)
晚上睡不着可能会让你感到焦虑或不舒服,但以下是一些可以帮助你入睡的方法:

1. 制定规律的睡眠时间表:保持规律的睡眠时间表可以帮助你建立健康的睡眠习惯,使你更容易入睡。尽量在每天的相同时间上床,并在同一时间起床。
2. 创造一个舒适的睡眠环境:确保睡眠环境舒适,安静,黑暗且温度适宜。可以使用舒适的床上用品,并保持房间通风。
3. 放松身心:在睡前做些放松的活动,例如泡个热水澡,听些轻柔的音乐,阅读一些有趣的书籍等,有助于缓解紧张和焦虑,使你更容易入睡。
4. 避免饮用含有咖啡因的饮料:咖啡因是一种刺激性物质,会影响你的睡眠质量。尽量避免在睡前饮用含有咖啡因的饮料,例如咖啡,茶和可乐。
5. 避免在床上做与睡眠无关的事情:在床上做些与睡眠无关的事情,例如看电影,玩游戏或工作等,可能会干扰你的睡眠。
6. 尝试呼吸技巧:深呼吸是一种放松技巧,可以帮助你缓解紧张和焦虑,使你更容易入睡。试着慢慢吸气,保持几秒钟,然后缓慢呼气。

如果这些方法无法帮助你入睡,你可以考虑咨询医生或睡眠专家,寻求进一步的建议。
⭐️下载ChatGLM-6B

1️⃣Modelscope下载:ChatGLM下载         

2️⃣ 阿里云OSS存储:(加快下载速度)其他云也一样

import os
dsw_region = os.environ.get("dsw_region")
# 从环境变量中获取dsw_region的值
url_link = {
    "cn-shanghai": "https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/chatGLM/ChatGLM-6B-main.tar.gz",
    "cn-hangzhou": "https://atp-modelzoo.oss-cn-hangzhou.aliyuncs.com/release/tutorials/chatGLM/ChatGLM-6B-main.tar.gz",
    "cn-shenzhen": "https://atp-modelzoo-sz.oss-cn-shenzhen-internal.aliyuncs.com/release/tutorials/chatGLM/ChatGLM-6B-main.tar.gz",
    "cn-beijing": "https://atp-modelzoo-bj.oss-cn-beijing.aliyuncs.com/release/tutorials/chatGLM/ChatGLM-6B-main.tar.gz", 
}

🌈我离北京较近,选择cn- beijing

path = url_link["cn-beijing"]
os.environ['LINK_CHAT'] = path
# 选择cn-beijing区域对应的URL,并将其赋值给path变量
!wget $LINK_CHAT
# wget命令下载LINK_CHAT环境变量指向的URL处的文件
!tar -xvf ChatGLM-6B-main.tar.gz

⭐️安装依赖 
!cd ChatGLM-6B-main && pip install -r requirements.txt && \
pip install rouge_chinese nltk jieba datasets gradio==3.37.0

 

数据准备

数据文件为json文件。json文件中每条数据是一个字典,记录输入文本和输出文本。 

我们使用官方的数据集:(AdvertiseGen_Simple)

!cd ChatGLM-6B-main/ptuning && wget https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/chatGLM/AdvertiseGen_Simple.zip  && unzip AdvertiseGen_Simple.zip

  • content 字段包含了用于生成文本的输入信息(如服装的类型、版型、风格等),而 summary 字段则包含了与输入信息相对应的输出文本(即描述服装的文本)

微调模型

train.sh: 

PRE_SEQ_LEN=8
LR=1e-2

CUDA_VISIBLE_DEVICES=0 python main.py \
    --do_train \
    --train_file AdvertiseGen_Simple/train.json \
    --validation_file AdvertiseGen_Simple/dev.json \
    --prompt_column content \
    --response_column summary \
    --overwrite_cache \
    --model_name_or_path chatglm-6b \
    --output_dir output/adgen-chatglm-6b-pt-$PRE_SEQ_LEN-$LR \
    --overwrite_output_dir \
    --max_source_length 64 \
    --max_target_length 64 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 16 \
    --predict_with_generate \
    --logging_steps 10 \
    --save_steps 6 \
    --learning_rate $LR \
    --pre_seq_len $PRE_SEQ_LEN \
    --quantization_bit 4 \
    --num_train_epochs 1
  • CUDA_VISIBLE_DEVICES=0:指定使用编号为0的GPU设备进行训练。
  • --quantization_bit 4:指定量化位数,这里设置为4位。
!cd ChatGLM-6B-main/ptuning && bash train.sh

模型推理

evaluate.sh: 

PRE_SEQ_LEN=8
CHECKPOINT=adgen-chatglm-6b-pt-8-1e-2
STEP=6

CUDA_VISIBLE_DEVICES=0 python main.py \
    --do_predict \
    --validation_file AdvertiseGen_Simple/dev.json \
    --test_file AdvertiseGen_Simple/dev.json \
    --overwrite_cache \
    --prompt_column content \
    --response_column summary \
    --model_name_or_path ./output/$CHECKPOINT/checkpoint-$STEP  \
    --output_dir ./output/$CHECKPOINT \
    --overwrite_output_dir \
    --max_source_length 64 \
    --max_target_length 64 \
    --per_device_eval_batch_size 1 \
    --predict_with_generate \
    --pre_seq_len $PRE_SEQ_LEN \
    --quantization_bit 4
  •  --do_predict:表示要进行预测
  • --overwrite_cache:如果缓存文件已存在,则覆盖它们。

  • --overwrite_output_dir:如果输出目录已存在,则覆盖它。

!cd ChatGLM-6B-main/ptuning &&  bash evaluate.sh

🍺生成的结果:

🍺完成后可以启动web_demo.py 启动网页对话:

!cd ChatGLM-6B-main/ && python web_demo.py

一般我们下载大模型后,会有cli_demo.py和 web_demo.py:

cli_demo.py

  • 用户可以通过命令行参数与脚本交互,输入文本并获取模型的生成结果。

web_demo.py

  • 用户可以通过浏览器访问一个网页,输入文本并获取模型的生成结果。

💯访问本机(服务器)的端口7860就可以访问到 :

  • 28
    点赞
  • 16
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 22
    评论
评论 22
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

小森( ﹡ˆoˆ﹡ )

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值