大家好,我是刘一手,专注于CV算法和多模态大模型在类教育场景的实际应用。
今天给大家带来《AI多模态教程:从0到1搭建VisualGLM图文大模型案例》
——————————————————————————————————————
一、模型介绍
开源多模态模型:VisualGLM-6B 是一个开源的对话模型,具备处理中英文对话和图像的能力**。
参数规模:模型拥有高达78亿参数,提供强大的语言和视觉处理能力。
语言支持:专门设计用于中英文对话,通过BLIP2-Qformer技术实现视觉与语言的有效整合。
预训练数据集:在CogView数据集上进行预训练,包含30M中文和300M英文图文对,确保语言处理的均衡。
微调优化:经过微调,模型能够生成更加贴合人类偏好的对话回答。
训练工具库:利用SwissArmyTransformer库进行训练,支持模型的灵活修改和高效微调。
部署便捷性:通过模型量化技术,可以在消费级显卡上部署,最低仅需8GB显存。
应用场景:模型适用于图像描述和知识问答任务,展示其在视觉和语言理解的综合应用能力。
二、仓库结构
2.1 克隆Github仓库
克隆命令:
git clone https://github.com/THUDM/VisualGLM-6B.git
如果读者访问不了外网,也可以从下面的云盘下载(见文末)。
2.2 仓库结构
三、环境安装
我使用的环境:(推荐租用AutoDL或者恒源云的云服务器,显卡显存在24G及以上就可以)
系统:Ubuntu22.04
CUDA驱动版本:11.7
显卡显存:RTX 3090 Ti 24GB
Python版本:3.8
VisualGLM模型的环境中有几个非常重要的依赖包,对版本有要求,版本不同可能会有各种报错,经过反复测试,下面的各版本可以正常运行最新代码(代码更新为2024年3月)
SwissArmyTransformer0.4.5
bitsandbytes0.39.0
transformers4.33.1
torch1.13.1
torchvision==0.14.1
这里一手也给大家准备了完整环境的压缩包(百度云盘链接见文末),使用方法:先在miniconda或者conda的envs目录新建一个文件夹,比如visualglm_env,然后进入这个文件夹内,把压缩包复制进去,直接解压就可以使用,免安装:
cd /home/miniconda3/envs
mkdir visualglm_env
tar -xzf visualglm_env.tar.gz -C visualglm_env
四、预训练权重下载
预训练模型是依靠来自于 CogView 数据集的30M高质量中文图文对,与300M经过筛选的英文图文对进行预训练,中英文权重相同。该训练方式较好地将视觉信息对齐到ChatGLM的语义空间;之后的微调阶段,模型在长视觉问答数据上训练,以生成符合人类偏好的答案。
官方提供了两种预训练模型:基于Huggingface的权重和基于 SwissArmyTransformer(简称sat) 的权重。
区别:Huggingface可以基于命令行和网页进行推理,但不可以用于训练;sat权重可以基于命令行和网页进行推理,可以当前预训练权重进行微调。这里推荐下载sat权重。
下载方法:
(1)官方推荐的方法
如果使用Huggingface transformers库调用模型,可以通过如下代码(其中图像路径为本地路径)自动下载权重文件。
PS:实际上的下载地址:https://huggingface.co/THUDM/visualglm-6b
from transformers import AutoTokenizer, AutoModel
tokenizer = AutoTokenizer.from_pretrained("THUDM/visualglm-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/visualglm-6b", trust_remote_code=True).half().cuda()
image_path = "your image path"
response, history = model.chat(tokenizer, image_path, "描述这张图片。", history=[])
print(response)
response, history = model.chat(tokenizer, image_path, "这张图片可能是在什么场所拍摄的?", history=history)
print(response)
如果使用SwissArmyTransformer库调用模型,方法类似,可以使用环境变量SAT_HOME决定模型下载位置。在本仓库目录下:
import argparse
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
from model import chat, VisualGLMModel
model, model_args = VisualGLMModel.from_pretrained('visualglm-6b', args=argparse.Namespace(fp16=True, skip_init=True))
from sat.model.mixins import CachedAutoregressiveMixin
model.add_mixin('auto-regressive', CachedAutoregressiveMixin())
image_path = "your image path or URL"
response, history, cache_image = chat(image_path, model, tokenizer, "描述这张图片。", history=[])
print(response)
response, history, cache_image = chat(None, model, tokenizer, "这张图片可能是在什么场所拍摄的?", history=history, image=cache_image)
print(response)
但是以上方法都不推荐!!因为你会因为网速慢、连接不上官网、文件大(~16G)下载速度慢等原因非常崩溃。因为一手推荐用第二种方法。
(2)通过网盘下载
下载链接见文末
除了预训练权重之外,我们还需要依赖ChatGLM来进行图文对话,这里不需要下载整个ChatGLM模型,只需要5个跟tokenizer相关的文件:
上面两个文件下载完成后解压到项目根目录:
五、预训练权重推理
我们先使用下载好的SAT权重进行命令行和网页端的推理,测试整体的环境安装是否正确,以及整个数据加载–>推理的流程能否跑通。
5.1 命令行推理
示例代码cli_demo.py:
代码1~2行:增加设置指定的GPU编号,可以根据自己显卡情况修改,如果只有一张卡,可以将1改为0;
代码22行:在quant参数里面设置量化大小,可以选择量化为4bit或者8bit;
代码23行:在from_pretrained参数里面修改为自己本地的visualglm-6b预训练权重路径;
代码50行:在from_pretrained函数里面修改为自己本地的chatglm tokenizer目录路径;
import os # 导入操作系统接口模块
os.environ['CUDA_VISIBLE_DEVICES'] = "1" # 设置环境变量,指定使用的第2个CUDA设备,从0开始编号
import sys # 导入系统模块,用于访问与Python解释器相关的变量和函数
import torch # 导入PyTorch深度学习框架
import argparse # 导入命令行参数解析模块
from transformers import AutoTokenizer # 从transformers库导入自动分词器
from sat.model.mixins import CachedAutoregressiveMixin # 从sat库导入自动回归混合类
from sat.quantization.kernels import quantize # 从sat库导入量化函数
from model import chat # 从model模块导入chat函数
from sat.model import AutoModel # 从sat.model模块导入AutoModel
def main():
parser = argparse.ArgumentParser() # 创建命令行参数解析器
# 添加命令行参数
parser.add_argument("--max_length", type=int, default=2048, help='max length of the total sequence') # 最大序列长度
parser.add_argument("--top_p", type=float, default=0.4, help='top p for nucleus sampling') # 核采样的top p值
parser.add_argument("--top_k", type=int, default=100, help='top k for top k sampling') # top k采样的top k值
parser.add_argument("--temperature", type=float, default=.8, help='temperature for sampling') # 采样温度
parser.add_argument("--english", action='store_true', help='only output English') # 是否只输出英文
parser.add_argument("--quant", choices=[8, 4], type=int, default=None, help='quantization bits') # 量化位数
parser.add_argument("--from_pretrained", type=str, default="./visualglm-6b",
help='pretrained ckpt') # 预训练模型路径
parser.add_argument("--prompt_zh", type=str, default="描述这张图片。",
help='Chinese prompt for the first round') # 中文提示语
parser.add_argument("--prompt_en", type=str, default="Describe the image.",
help='English prompt for the first round') # 英文提示语
args = parser.parse_args() # 解析命令行参数
# 加载模型
model, model_args = AutoModel.from_pretrained( # 使用from_pretrained方法加载预训练模型
args.from_pretrained,
args=argparse.Namespace(
fp16=True, # 是否使用半精度浮点数
skip_init=True, # 是否跳过初始化
use_gpu_initialization=True if (torch.cuda.is_available() and args.quant is None) else False, # 是否使用GPU初始化
device='cuda' if (torch.cuda.is_available() and args.quant is None) else 'cpu', # 设备选择
)
)
model = model.eval() # 设置模型为评估模式
if args.quant: # 如果指定了量化位数
quantize(model, args.quant) # 对模型进行量化
if torch.cuda.is_available(): # 如果CUDA可用
model = model.cuda() # 将模型移动到GPU
model.add_mixin('auto-regressive', CachedAutoregressiveMixin()) # 给模型添加自动回归混合类
tokenizer = AutoTokenizer.from_pretrained("./chatglm/", trust_remote_code=True) # 加载分词器
if not args.english: # 如果不是英文模式
print(
'欢迎使用 VisualGLM-6B 模型,输入图像URL或本地路径读图,继续输入内容对话,clear 重新开始,stop 终止程序') # 打印使用说明
else: # 英文模式
print(
'Welcome to VisualGLM-6B model. Enter an image URL or local file path to load an image. Continue '
'inputting text to engage in a conversation. Type "clear" to start over, or "stop" to end the program.')
# 打印英文使用说明
with torch.no_grad(): # 禁用梯度计算
while True: # 进入主循环
history = None # 初始化历史对话记录
cache_image = None # 初始化缓存的图像
if not args.english: # 如果不是英文模式
image_path = input("请输入图像路径或URL(回车进入纯文本对话): ") # 输入图像路径或URL
else: # 英文模式
image_path = input(
"Please enter the image path or URL (press Enter for plain text conversation): ") # 输入图像路径或URL
if image_path == 'stop': # 如果输入stop
break # 退出循环
if len(image_path) > 0: # 如果输入了图像路径
query = args.prompt_en if args.english else args.prompt_zh # 设置查询提示语
else: # 如果没有输入图像路径,进入纯文本对话
if not args.english:
query = input("用户:") # 输入中文用户对话
else:
query = input("User: ") # 输入英文用户对话
while True: # 进入对话循环
if query == "clear": # 如果用户输入clear
break # 重置对话
if query == "stop": # 如果用户输入stop
sys.exit(0) # 退出程序
try: # 尝试执行对话
response, history, cache_image = chat( # 调用chat函数进行对话
image_path,
model,
tokenizer,
query,
history=history,
image=cache_image,
max_length=args.max_length,
top_p=args.top_p,
temperature=args.temperature,
top_k=args.top_k,
english=args.english,
invalid_slices=[slice(63823, 130000)] if args.english else []
)
except Exception as e: # 如果发生异常
print(e) # 打印异常信息
break # 退出循环
sep = 'A:' if args.english else '答:' # 设置分隔符
print("VisualGLM-6B:" + response.split(sep)[-1].strip()) # 打印模型的回复
image_path = None # 重置图像路径
if not args.english: # 如果不是英文模式
query = input("用户:") # 输入中文用户对话
else: # 英文模式
query = input("User: ") # 输入英文用户对话
if __name__ == "__main__": # 如果是主模块
main() # 调用主函数
运行之后可以进行文字对话,也可以输入图片路径进行图像理解和对话。
文字对话:
图文理解:
经过测试在显卡型号为3090TI的设备下,不同量化后的模型所占显存大小如下,单位GB:
无量化 | INT8 | INT4 |
---|---|---|
19.8 | 14.3 | 8.8 |
5.2 网页版推理
示例代码web_demo.py:
代码1~2行:增加设置指定的GPU编号,可以根据自己显卡情况修改,如果只有一张卡,可以将1改为0;
代码131行和159行:修改为自己本地的chatglm tokenizer目录路径;
代码205行:在quant参数里面设置量化大小,可以选择量化为4bit或者8bit;
代码206行:是否共享应用,若为True,则会生成一个公网链接,可分享给其他人使用;
代码207行:在from_pretrained参数里面修改为自己本地的visualglm-6b预训练权重路径;
import os # 导入操作系统接口模块
os.environ['CUDA_VISIBLE_DEVICES'] = "1" # 设置环境变量,指定使用的第一个CUDA设备
import argparse # 导入命令行参数解析模块
import gradio as gr # 导入Gradio库,用于创建Web应用界面
from PIL import Image # 从PIL库导入Image类,用于图像处理
from model import is_chinese, generate_input # 从model模块导入is_chinese和generate_input函数
import torch # 导入PyTorch深度学习框架
from transformers import AutoTokenizer # 从transformers库导入自动分词器
from finetune_visualglm import FineTuneVisualGLMModel
from model import chat # 从model模块导入chat函数
from sat.model import AutoModel # 从sat.model模块导入AutoModel
from sat.model.mixins import CachedAutoregressiveMixin # 从sat.model.mixins模块导入CachedAutoregressiveMixin
from sat.quantization.kernels import quantize # 从sat.quantization.kernels模块导入quantize函数
# 定义一个函数,用于根据输入文本和图像生成文本
def generate_text_with_image(input_text, image, history=[], request_data=dict(), is_zh=True):
# 设置默认的输入参数
input_para = {
"max_length": 2048,
"min_length": 50,
"temperature": 0.8,
"top_p": 0.4,
"top_k": 100,
"repetition_penalty": 1.2
}
# 更新输入参数
input_para.update(request_data)
# 生成输入数据
input_data = generate_input(input_text, image, history, input_para, image_is_encoded=False)
input_image, gen_kwargs = input_data['input_image'], input_data['gen_kwargs']
# 在不计算梯度的情况下执行模型
with torch.no_grad():
# 使用chat函数生成回答
answer, history, _ = chat(None, model, tokenizer, input_text, history=history, image=input_image, \
max_length=gen_kwargs['max_length'], top_p=gen_kwargs['top_p'], \
top_k=gen_kwargs['top_k'], temperature=gen_kwargs['temperature'], english=not is_zh)
return answer
# 定义一个函数,用于处理模型请求
def request_model(input_text, temperature, top_p, image_prompt, result_previous):
# 处理历史记录
result_text = [(ele[0], ele[1]) for ele in result_previous]
# 清理历史记录
for i in range(len(result_text) - 1, -1, -1):
if result_text[i][0] == "" or result_text[i][1] == "":
del result_text[i]
print(f"history {result_text}")
# 判断输入文本是否为中文
is_zh = is_chinese(input_text)
# 如果没有图像提示,给出错误信息
if image_prompt is None:
if is_zh:
result_text.append((input_text, '图片为空!请上传图片并重试。'))
else:
result_text.append((input_text, 'Image empty! Please upload a image and retry.'))
return input_text, result_text
# 如果输入文本为空,给出错误信息
elif input_text == "":
result_text.append((input_text, 'Text empty! Please enter text and retry.'))
return "", result_text
# 设置请求参数
request_para = {"temperature": temperature, "top_p": top_p}
# 打开图像文件
image = Image.open(image_prompt)
try:
# 生成文本
answer = generate_text_with_image(input_text, image, result_text.copy(), request_para, is_zh)
except Exception as e:
print(f"error: {e}")
if is_zh:
result_text.append((input_text, '超时!请稍等几分钟再重试。'))
else:
result_text.append((input_text, 'Timeout! Please wait a few minutes and retry.'))
return "", result_text
# 添加新的对话到历史记录
result_text.append((input_text, answer))
print(result_text)
return "", result_text
# 设置应用描述
DESCRIPTION = '''# <a href="">Visual-GLM</a>'''
# 设置维护通知
MAINTENANCE_NOTICE1 = 'Hint 1: If the app report "Something went wrong, connection error out", please turn off your proxy and retry.\nHint 2: If you upload a large size of image like 10MB, it may take some time to upload and process. Please be patient and wait.'
MAINTENANCE_NOTICE2 = '提示1: 如果应用报了“Something went wrong, connection error out”的错误,请关闭代理并重试。\n提示2: 如果你上传了很大的图片,比如10MB大小,那将需要一些时间来上传和处理,请耐心等待。'
# 设置注释
NOTES = 'This app is adapted from <a href="https://github.com/THUDM/VisualGLM-6B">https://github.com/THUDM/VisualGLM-6B</a>. It would be recommended to check out the repo if you want to see the detail of our model and training process.'
# 定义清除函数
def clear_fn(value):
return "", [("", "Hi, What do you want to know about this image?")], None
# 定义清除函数2
def clear_fn2(value):
return [("", "Hi, What do you want to know about this image?")]
# 获取模型的函数
def get_model(args):
global model, tokenizer
# 加载模型
model, model_args = AutoModel.from_pretrained(
args.from_pretrained,
args=argparse.Namespace(
fp16=True,
skip_init=True,
use_gpu_initialization=True if (torch.cuda.is_available() and args.quant is None) else False,
device='cuda' if (torch.cuda.is_available() and args.quant is None) else 'cpu',
))
model = model.eval()
if args.quant:
quantize(model.transformer, args.quant)
if torch.cuda.is_available():
model = model.cuda()
model.add_mixin('auto-regressive', CachedAutoregressiveMixin())
tokenizer = AutoTokenizer.from_pretrained(
"./chatglm",
trust_remote_code=True)
# tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
return model, tokenizer
# 主函数
def main(args):
global model, tokenizer
# 加载模型
model, model_args = AutoModel.from_pretrained(
args.from_pretrained,
args=argparse.Namespace(
fp16=True,
skip_init=True,
use_gpu_initialization=True if (torch.cuda.is_available() and args.quant is None) else False,
device='cuda' if (torch.cuda.is_available() and args.quant is None) else 'cpu',
))
model = model.eval()
if args.quant:
quantize(model.transformer, args.quant)
if torch.cuda.is_available():
model = model.cuda()
model.add_mixin('auto-regressive', CachedAutoregressiveMixin())
tokenizer = AutoTokenizer.from_pretrained("./chatglm", trust_remote_code=True)
# 使用Gradio创建界面
with gr.Blocks(css='style.css') as demo:
gr.Markdown(DESCRIPTION)
with gr.Row():
with gr.Column(scale=4.5):
with gr.Group():
input_text = gr.Textbox(label='Input Text',
placeholder='Please enter text prompt below and press ENTER.')
with gr.Row():
run_button = gr.Button('Generate')
clear_button = gr.Button('Clear')
image_prompt = gr.Image(type="filepath", label="Image Prompt", value=None)
with gr.Row():
temperature = gr.Slider(maximum=1, value=0.8, minimum=0, label='Temperature')
top_p = gr.Slider(maximum=1, value=0.4, minimum=0, label='Top P')
with gr.Group():
with gr.Row():
maintenance_notice = gr.Markdown(MAINTENANCE_NOTICE1)
with gr.Column(scale=5.5):
result_text = gr.components.Chatbot(label='Multi-round conversation History', value=[
("", "Hi, What do you want to know about this image?")]).style(height=550)
gr.Markdown(NOTES)
# 设置Gradio版本信息
print(gr.__version__)
# 设置按钮点击事件
run_button.click(fn=request_model, inputs=[input_text, temperature, top_p, image_prompt, result_text],
outputs=[input_text, result_text])
input_text.submit(fn=request_model, inputs=[input_text, temperature, top_p, image_prompt, result_text],
outputs=[input_text, result_text])
clear_button.click(fn=clear_fn, inputs=clear_button, outputs=[input_text, result_text, image_prompt])
image_prompt.upload(fn=clear_fn2, inputs=clear_button, outputs=[result_text])
image_prompt.clear(fn=clear_fn2, inputs=clear_button, outputs=[result_text])
# 启动Gradio应用
demo.queue(concurrency_count=10) # 设置并发请求数
demo.launch(share=args.share) # 启动应用,允许共享
# 主程序入口
if __name__ == '__main__':
parser = argparse.ArgumentParser() # 设置命令行参数解析器
parser.add_argument("--quant", choices=[8, 4], type=int, default=4) # 设置量化位数参数
parser.add_argument("--share", action="store_true", default=True) # 是否共享应用
parser.add_argument("--from_pretrained", type=str, default="./visualglm-6b", help='pretrained ckpt') # 预训练模型路径
args = parser.parse_args() # 解析命令行参数
main(args) # 调用主函数
运行之后进入网址,可以输入图片路径和问题进行图像理解对话:
六、模型微调训练方法
模型微调步骤:按照官方样例准备好数据集、设置微调参数、启动微调训练、使用微调权重推理
6.1 数据准备
官方提供了一个微调数据的格式,解压fewshot-data.zip得到如下结构:
也就是我们需要准备好需要微调的图片+一个dataset.json文件,后者里面存放的是图文对:
可以看到dataset.json的内容是一个大列表,列表里面的每一个元素是字典格式:img表示图片的路径、prompt表示提问的文本,label表示回答。这个示例数据主要关注图片的背景内容,所以回复也是背景是xxx的形式。实际上在创建自己的数据集时,prompt和label可以修改为自己想要实现的问题-答案对。
6.2 配置微调脚本
目前支持三种方式的微调:
LoRA:样例中为ChatGLM模型的第0层和第14层加入了rank=10的LoRA微调,可以根据具体情景和数据量调整–layer_range和–lora_rank参数。
QLoRA:如果资源有限,可以考虑使用bash
finetune/finetune_visualglm_qlora.sh,QLoRA将ChatGLM的线性层进行了4-bit量化,只需要9.8GB显存即可微调。
P-tuning:可以将–use_lora替换为–use_ptuning,不过不推荐使用,除非模型应用场景非常固定。
注意微调需要安装deepspeed库,目前本流程仅支持linux系统。
下面是在单卡上使用LoRA在示例数据集fewshot-data的微调脚本:finetune_visualglm.sh,batch-size=2时训练需要18GB显存。
#! /bin/bash
NUM_WORKERS=1
NUM_GPUS_PER_WORKER=8
MP_SIZE=1
script_path=$(realpath $0)
script_dir=$(dirname $script_path)
main_dir=$(dirname $script_dir)
MODEL_TYPE="visualglm-6b"
MODEL_ARGS="--max_source_length 64 \
--max_target_length 256 \
--lora_rank 10 \
--layer_range 0 14 \
--pre_seq_len 4"
# OPTIONS_SAT="SAT_HOME=$1" #"SAT_HOME=/raid/dm/sat_models"
OPTIONS_NCCL="NCCL_DEBUG=info NCCL_IB_DISABLE=0 NCCL_NET_GDR_LEVEL=2"
HOST_FILE_PATH="hostfile"
HOST_FILE_PATH="hostfile_single"
train_data="./fewshot-data/dataset.json"
eval_data="./fewshot-data/dataset.json"
gpt_options=" \
--experiment-name finetune-$MODEL_TYPE \
--model-parallel-size ${MP_SIZE} \
--mode finetune \
--train-iters 1000 \
--resume-dataloader \
$MODEL_ARGS \
--train-data ${train_data} \
--valid-data ${eval_data} \
--distributed-backend nccl \
--lr-decay-style cosine \
--warmup .02 \
--checkpoint-activations \
--save-interval 300 \
--eval-interval 100 \
--save "./checkpoints" \
--split 1 \
--eval-iters 10 \
--eval-batch-size 1 \
--zero-stage 1 \
--lr 0.0001 \
--batch-size 2 \
--skip-init \
--fp16 \
--use_lora
"
run_cmd="${OPTIONS_NCCL} ${OPTIONS_SAT} deepspeed --master_port 16666 --include localhost:1 --hostfile ${HOST_FILE_PATH} finetune_visualglm.py ${gpt_options}"
echo ${run_cmd}
eval ${run_cmd}
set +x
微调参数说明:
这个脚本是一个用于配置和启动VisualGLM训练任务的Bash脚本,它包含了多个参数,用于设置训练环境和训练过程的各个方面。下面是对脚本中参数的详细说明:
NUM_WORKERS: 指定训练过程中使用的worker数量。每个worker可以运行在不同的机器上。
NUM_GPUS_PER_WORKER: 每个worker使用的GPU数量。这决定了每个训练进程可以利用的GPU资源。 MP_SIZE:
模型并行大小,即模型被分割成多少部分在不同的进程中并行处理。 script_path: 当前脚本的绝对路径。 script_dir:
当前脚本所在的目录。 main_dir: 脚本所在目录的上一级目录,通常用于存放项目的主要文件。 MODEL_TYPE:
使用的模型类型,这里是visualglm-6b,表示一个特定的视觉语言模型。 MODEL_ARGS: 一系列与模型相关的参数,包括:
–max_source_length 64: 输入序列的最大长度。
–max_target_length 256: 输出序列的最大长度。
–lora_rank 10: LoRA(局部重参数化)的秩。
–layer_range 0 14: 参与并行处理的模型层的范围。
–pre_seq_len 4: 预序列的长度。 OPTIONS_SAT: 环境变量设置,这里被注释掉了,如果取消注释,它将设置SAT_HOME环境变量。 OPTIONS_NCCL: 一系列与NVIDIA
Collective Communications Library (NCCL) 相关的环境变量设置,用于优化GPU之间的通信。
HOST_FILE_PATH: 指定主机文件的路径,该文件包含了参与训练的所有机器的列表。 train_data 和 eval_data:
分别指定训练数据和评估数据的路径。 gpt_options: 一系列用于配置训练任务的参数,包括:
–experiment-name: 实验名称。
–model-parallel-size: 模型并行大小。
–mode: 训练模式,这里是finetune。
–train-iters: 训练迭代次数。
–resume-dataloader: 是否恢复数据加载器的状态。
–train-data 和 --valid-data: 分别指定训练和验证数据的路径。
–distributed-backend: 分布式训练的后端,这里是nccl。
–lr-decay-style: 学习率衰减方式。
–warmup: 学习率预热的比例。
–save-interval 和 --eval-interval: 分别指定保存和评估的间隔。
–save: 模型保存的路径。
–split: 数据集分割的比例。
–eval-iters 和 --eval-batch-size: 分别指定评估的迭代次数和批次大小。
–zero-stage: 零优化的阶段。
–lr: 学习率。
–batch-size: 批次大小。
–skip-init: 是否跳过模型初始化。
–fp16: 是否使用半精度浮点数。
–use_lora: 是否使用LoRA技术。 run_cmd: 构建用于启动训练任务的命令字符串,包括NCCL和SAT的环境变量设置,以及deepspeed训练脚本的调用。 echo
${run_cmd}: 打印构建的命令字符串。 eval ${run_cmd}: 执行构建的命令字符串,启动训练任务。 set +x:
在脚本的最后,这行命令用于关闭xtrace(调试模式),这样在执行脚本时不会打印出所有的命令。
6.3 训练过程
在终端进入项目根目录,使用下面的命令启动训练:
bash finetune/finetune_visualglm.sh
在微调训练过程中会打印当前训练步数、学习率、损失值:
在eval-interval步数后开始验证,并打印在验证集上的损失值和PPL(PPL全称Perplexity ,指困惑度,用来衡量语言模型好坏的指标。简单说,perplexity值刻画的是语言模型预测一个语言样本的能力。在一个测试集上得到的perplexity值越低,说明建模效果越好):
在save-interval步数后进行保存,并打印保存路径:
在train-iters步数后结束训练,保存最后的权重并关闭GPU连接:
在上述过程中,还会通过tensorboard将更多日志信息保存在runs文件夹,使用下面的命令打开查看:
tensorboard --logdir==runs --port 6007 --bind_all
这里能看到训练损失、验证损失、验证指标的变化情况(PPL值越小越好)
6.4 微调权重推理
微调后的权重存放在checkpoints目录下,加载方式web_demo.py相同,只需要把主函数中的from_pretrained参数修改为自己本地的微调权重路径:
.....
# 主程序入口
if __name__ == '__main__':
parser = argparse.ArgumentParser() # 设置命令行参数解析器
parser.add_argument("--quant", choices=[8, 4], type=int, default=None) # 设置量化位数参数
parser.add_argument("--share", action="store_true", default=True) # 是否共享应用
parser.add_argument("--from_pretrained", type=str, default="./checkpoints/finetune-visualglm-6b-02-27-14-41", help='pretrained ckpt') # 预训练模型路径
args = parser.parse_args() # 解析命令行参数
main(args) # 调用主函数
......
运行后上传图片进行推理:
七、模型部署
这里为大家提供一种基于Flask API的方式部署VisualGLM服务。
步骤一:主服务功能配置;
步骤二:API接口调用;
好处:调用速度快,可以灵活地在不同场景下使用。
7.1 实现主服务
主服务是指基于Flask构建一个后台运行的算法服务,使用可以直接使用API的方式调用。
主服务代码image_caption_server.py:
代码16行:from_pretrained参数可以设置为预训练权重路径或者微调权重路径; loguru
用来记录推理结果或者记录程序异常日志,方便后期进行bug排查;
import os
os.environ['CUDA_VISIBLE_DEVICES'] = "1"
import time
import torch
import argparse
from flask import Flask, request
from flask_cors import cross_origin
from loguru import logger
from web_demo import get_model
from model import is_chinese, generate_input, chat
parser = argparse.ArgumentParser()
parser.add_argument("--quant", choices=[8, 4], type=int, default=None)
parser.add_argument("--share", action="store_true", default=False)
parser.add_argument("--from_pretrained", type=str, default="checkpoints/finetune-visualglm-6b-02-27-14-41",
help='pretrained ckpt')
args = parser.parse_args()
model, tokenizer = get_model(args)
ti_ = time.localtime()
date_ = f"{ti_[0]}_{ti_[1]}_{ti_[2]}"
current_time_path = f"./logs/runtime/{date_}"
os.makedirs(current_time_path, exist_ok=True)
logger.remove(handler_id=None)
logger.add(os.path.join(current_time_path, "runtime_{time}.log"), rotation='1 day')
app = Flask(__name__)
@app.route('/image_caption', methods=['POST'])
@logger.catch()
@cross_origin()
def image_caption():
if request.method == "POST":
try:
print("Start to process request")
request_data = request.get_json()
input_text, input_image_encoded, history = request_data['text'], request_data['image'], request_data[
'history']
input_para = {
"max_length": 2048,
"min_length": 50,
"temperature": 0.8,
"top_p": 0.4,
"top_k": 100,
"repetition_penalty": 1.2
}
input_para.update(request_data)
is_zh = is_chinese(input_text)
input_data = generate_input(input_text, input_image_encoded, history, input_para)
input_image, gen_kwargs = input_data['input_image'], input_data['gen_kwargs']
with torch.no_grad():
answer, history, _ = chat(None, model, tokenizer, input_text, history=history, image=input_image, \
max_length=gen_kwargs['max_length'], top_p=gen_kwargs['top_p'], \
top_k=gen_kwargs['top_k'], temperature=gen_kwargs['temperature'],
english=not is_zh)
except Exception as e:
logger.info(e)
answer = "暂无相关描述,请检查图像内容!"
response = {
"result": answer,
"history": history,
"status": 200,
}
logger.info(answer)
return response
if __name__ == '__main__':
app.run(debug=False, host='0.0.0.0', port=9500)
7.2 实现API接口调用
主服务相当于随时准备推理的模块,我们还需要通过API调用主服务的功能。
API调用代码image_caption_api.py:
import base64
import requests
def visual_api(imgb64):
# 请求数据
data = {
"text": "请详细描述这张图片的内容",
"image": imgb64,
"history": []
}
# 发起POST请求
response = requests.post('http://127.0.0.1:9500/image_caption', json=data)
result = response.json()
return result
if __name__ == '__main__':
one_img = "examples/3.jpeg"
imgbase64 = base64.b64encode(open(one_img, 'rb').read()).rstrip().decode('utf-8')
res = visual_api(imgbase64)
print("回答:", res['result'])
首先运行主服务代码image_caption_server.py,运行成功后会显示调用的链接:
然后在API调用程序image_caption_api.py里面设置图片路径和提问的text,text默认是"请详细描述这张图片的内容",运行程序就可以得到推理结果:
八、网盘链接汇总
1、VisualGLM代码文件:
链接:https://pan.baidu.com/s/15GfpCubvSgnkbSNtVnrKbQ?pwd=xy2t2、预训练权重文件
链接:https://pan.baidu.com/s/1MLMIb42AOJnqa4R1so6dxg?pwd=23uf3、chatglm tokenizer文件
链接:https://pan.baidu.com/s/1kq25Jdab5_p1umDgJaSdyQ?pwd=9ir84、python虚拟环境压缩包
链接:https://pan.baidu.com/s/1-7VA9s2iwjFMyqSOZhiMgQ?pwd=u39c
九、常见错误及解决方法
1、运行web_demo.py报错:Permission denied: /tmp/gradio/xx
解决方法:更改/tmp/gradio目录权限为可读可写:sudo chmod -R 777 /tmp/gradio/2、运行cli_demo.py或者web_demo.py报错:model_class FineTuneVisualGLMModel not
found
解决方法:from finetune_visualglm import FineTuneVisualGLMModel3、ChatGLMTokenizer’ object has no attribute 'tokenizer
解决方法:重装transformers到4.33.1
写在后面
如果大家对多模态大模型感兴趣,可以扫码加群学习交流,二维码失效可以添加微信:lzz9527288