Fine-tuning Llama3 for Multimodal Image Understanding with 24 GB of GPU Memory

Fine-tuning a Llama3 Multimodal Image-Understanding Model with XTuner

This experiment fine-tunes our own LLaVA-style multimodal image-text understanding model, based on Llama3-8B-Instruct and the Image Projector pretrained by the XTuner team. The platform is InternStudio, and the experiment uses 24 GB of GPU memory.

Environment, Model, and Data Preparation

Setting up the Environment

First, set up the environment. The following commands create a base environment with python=3.10 and pytorch=2.1.2+cu121.

conda create -n llama3 python=3.10
conda activate llama3
conda install pytorch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 pytorch-cuda=12.1 -c pytorch -c nvidia
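
As an optional sanity check before moving on, you can verify that this PyTorch build can see the GPU:

# should print 2.1.2 and True on a correctly configured machine
python -c "import torch; print(torch.__version__, torch.cuda.is_available())"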

Next, install XTuner.

cd ~
git clone -b v0.1.18 https://github.com/InternLM/XTuner
cd XTuner
pip install -e .[all]
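
As an optional check that the editable install succeeded, XTuner's list-cfg subcommand prints its built-in configs; the LLaVA-related entries should appear:

# list built-in configs and filter for LLaVA ones
xtuner list-cfg | grep -i llava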

Finally, clone the Llama3-Tutorial repository.

cd ~
git clone https://github.com/SmartFlowAI/Llama3-Tutorial

Model Preparation

Preparing the Llama3 weights
Before fine-tuning begins, prepare the Llama3-8B-Instruct model weights. On InternStudio, the weights are already available in the internal share folder, so symlinks are all we need.

mkdir -p ~/model
cd ~/model
ln -s /root/share/new_models/meta-llama/Meta-Llama-3-8B-Instruct .

Next, prepare the openai/clip-vit-large-patch14-336 weights that LLaVA requires, i.e., the Visual Encoder weights.

mkdir -p ~/model
cd ~/model
ln -s /root/share/new_models/openai/clip-vit-large-patch14-336 .

Then prepare the Image Projector weights.

mkdir -p ~/model
cd ~/model
ln -s /root/share/new_models/xtuner/llama3-llava-iter_2181.pth .
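
At this point ~/model should contain three symlinks. A quick listing confirms they all point into the shared weights directory:

ls -l ~/model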

Data Preparation
To get a working result quickly, we deliberately overfit on a small dataset by repeating each sample many times. Run the following:

cd ~
git clone https://github.com/InternLM/tutorial -b camp2
python ~/tutorial/xtuner/llava/llava_data/repeat.py \
  -i ~/tutorial/xtuner/llava/llava_data/unique_data.json \
  -o ~/tutorial/xtuner/llava/llava_data/repeated_data.json \
  -n 200
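
The repeat.py script duplicates the samples in unique_data.json (here 200 times, via -n) so the model can quickly overfit this small dataset. As an optional check, and assuming the output file is a top-level JSON list as in the standard LLaVA data format, you can count the repeated samples:

# prints the number of samples in the repeated dataset
python -c "import json; print(len(json.load(open('/root/tutorial/xtuner/llava/llava_data/repeated_data.json'))))"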

If you want to build your own dataset instead, refer to the corresponding section of the tutorial.

Fine-tuning Process

Start training with the following command. Note that the work directory /root/llama3_llava_pth is where checkpoints such as iter_1200.pth are saved; the conversion and chat commands below reference it.

xtuner train /root/Llama3-Tutorial/configs/llama3-llava/llava_llama3_8b_instruct_qlora_clip_vit_large_p14_336_lora_e1_finetune.py --work-dir /root/llama3_llava_pth --deepspeed deepspeed_zero2_offload

Training needs about 14170 MiB of GPU memory and takes roughly 4 hours and 48 minutes:

05/09 11:24:43 - mmengine - INFO - Iter(train) [  10/1200]  lr: 5.1430e-05  eta: 4:48:31  time: 14.5475  data_time: 0.0383  memory: 14172  loss: 1.4195
05/09 11:26:59 - mmengine - INFO - Iter(train) [  20/1200]  lr: 1.0857e-04  eta: 4:36:40  time: 13.5897  data_time: 0.0298  memory: 14170  loss: 0.5519
05/09 11:29:13 - mmengine - INFO - Iter(train) [  30/1200]  lr: 1.6571e-04  eta: 4:30:23  time: 13.4613  data_time: 0.0273  memory: 14170  loss: 0.5242
05/09 11:31:29 - mmengine - INFO - Iter(train) [  40/1200]  lr: 2.0000e-04  eta: 4:26:41  time: 13.5791  data_time: 0.0361  memory: 14170  loss: 0.4559
05/09 11:33:52 - mmengine - INFO - Iter(train) [  50/1200]  lr: 1.9994e-04  eta: 4:26:19  time: 14.3003  data_time: 0.0397  memory: 14169  loss: 0.3685
05/09 11:36:12 - mmengine - INFO - Iter(train) [  60/1200]  lr: 1.9981e-04  eta: 4:24:22  time: 14.0092  data_time: 0.0411  memory: 14172  loss: 0.4149
05/09 11:38:31 - mmengine - INFO - Iter(train) [  70/1200]  lr: 1.9960e-04  eta: 4:22:07  time: 13.9411  data_time: 0.0617  memory: 14170  loss: 0.2311
05/09 11:40:50 - mmengine - INFO - Iter(train) [  80/1200]  lr: 1.9933e-04  eta: 4:19:34  time: 13.8193  data_time: 0.0326  memory: 14170  loss: 0.1266
05/09 11:43:11 - mmengine - INFO - Iter(train) [  90/1200]  lr: 1.9898e-04  eta: 4:17:38  time: 14.0893  data_time: 0.0308  memory: 14170  loss: 0.1708
05/09 11:45:30 - mmengine - INFO - Iter(train) [ 100/1200]  lr: 1.9856e-04  eta: 4:15:22  time: 13.9608  data_time: 0.0320  memory: 14169  loss: 0.1202
05/09 11:47:51 - mmengine - INFO - Iter(train) [ 110/1200]  lr: 1.9807e-04  eta: 4:13:21  time: 14.1097  data_time: 0.0408  memory: 14172  loss: 0.0798
05/09 11:50:13 - mmengine - INFO - Iter(train) [ 120/1200]  lr: 1.9750e-04  eta: 4:11:26  time: 14.2204  data_time: 0.0433  memory: 14170  loss: 0.2290
05/09 11:52:32 - mmengine - INFO - Iter(train) [ 130/1200]  lr: 1.9687e-04  eta: 4:08:59  time: 13.8805  data_time: 0.0462  memory: 14170  loss: 0.1079
05/09 11:54:49 - mmengine - INFO - Iter(train) [ 140/1200]  lr: 1.9616e-04  eta: 4:06:14  time: 13.6250  data_time: 0.0364  memory: 14169  loss: 0.0395
05/09 11:56:58 - mmengine - INFO - Iter(train) [ 150/1200]  lr: 1.9539e-04  eta: 4:02:43  time: 12.9145  data_time: 0.0451  memory: 14168  loss: 0.1383
05/09 11:59:17 - mmengine - INFO - Iter(train) [ 160/1200]  lr: 1.9454e-04  eta: 4:00:29  time: 13.9496  data_time: 0.0380  memory: 14172  loss: 0.1375
05/09 12:01:35 - mmengine - INFO - Iter(train) [ 170/1200]  lr: 1.9363e-04  eta: 3:58:05  time: 13.7808  data_time: 0.0380  memory: 14170  loss: 0.0263
05/09 12:03:51 - mmengine - INFO - Iter(train) [ 180/1200]  lr: 1.9264e-04  eta: 3:55:30  time: 13.5891  data_time: 0.0426  memory: 14170  loss: 0.2581
05/09 12:06:10 - mmengine - INFO - Iter(train) [ 190/1200]  lr: 1.9159e-04  eta: 3:53:16  time: 13.9399  data_time: 0.0498  memory: 14170  loss: 0.1230
05/09 12:08:27 - mmengine - INFO - Iter(train) [ 200/1200]  lr: 1.9048e-04  eta: 3:50:49  time: 13.6799  data_time: 0.0493  memory: 14169  loss: 0.0601
05/09 12:10:48 - mmengine - INFO - Iter(train) [ 210/1200]  lr: 1.8930e-04  eta: 3:48:41  time: 14.0703  data_time: 0.0408  memory: 14172  loss: 0.1540
05/09 12:13:11 - mmengine - INFO - Iter(train) [ 220/1200]  lr: 1.8805e-04  eta: 3:46:41  time: 14.2892  data_time: 0.0405  memory: 14170  loss: 0.1472
05/09 12:15:32 - mmengine - INFO - Iter(train) [ 230/1200]  lr: 1.8674e-04  eta: 3:44:35  time: 14.1704  data_time: 0.0360  memory: 14170  loss: 0.0342
05/09 12:17:53 - mmengine - INFO - Iter(train) [ 240/1200]  lr: 1.8536e-04  eta: 3:42:24  time: 14.1000  data_time: 0.0391  memory: 14170  loss: 0.0607
05/09 12:20:13 - mmengine - INFO - Iter(train) [ 250/1200]  lr: 1.8393e-04  eta: 3:40:06  time: 13.9302  data_time: 0.0382  memory: 14169  loss: 0.0634
05/09 12:22:35 - mmengine - INFO - Iter(train) [ 260/1200]  lr: 1.8243e-04  eta: 3:38:01  time: 14.2698  data_time: 0.0495  memory: 14172  loss: 0.0763
05/09 12:24:57 - mmengine - INFO - Iter(train) [ 270/1200]  lr: 1.8087e-04  eta: 3:35:48  time: 14.1202  data_time: 0.0336  memory: 14170  loss: 0.0760
05/09 12:27:27 - mmengine - INFO - Iter(train) [ 280/1200]  lr: 1.7925e-04  eta: 3:34:07  time: 15.0706  data_time: 0.0356  memory: 14170  loss: 0.0674
05/09 12:29:56 - mmengine - INFO - Iter(train) [ 290/1200]  lr: 1.7758e-04  eta: 3:32:15  time: 14.8495  data_time: 0.0328  memory: 14170  loss: 0.0619
05/09 12:32:17 - mmengine - INFO - Iter(train) [ 300/1200]  lr: 1.7585e-04  eta: 3:29:59  time: 14.1217  data_time: 0.0418  memory: 14169  loss: 0.0213
05/09 12:34:36 - mmengine - INFO - Iter(train) [ 310/1200]  lr: 1.7406e-04  eta: 3:27:36  time: 13.8896  data_time: 0.0345  memory: 14172  loss: 0.0649
05/09 12:37:04 - mmengine - INFO - Iter(train) [ 320/1200]  lr: 1.7222e-04  eta: 3:25:39  time: 14.8585  data_time: 0.0411  memory: 14172  loss: 0.0383
05/09 12:39:27 - mmengine - INFO - Iter(train) [ 330/1200]  lr: 1.7033e-04  eta: 3:23:26  time: 14.2799  data_time: 0.0332  memory: 14170  loss: 0.0712
05/09 12:41:52 - mmengine - INFO - Iter(train) [ 340/1200]  lr: 1.6838e-04  eta: 3:21:18  time: 14.5005  data_time: 0.0291  memory: 14170  loss: 0.0852
05/09 12:44:11 - mmengine - INFO - Iter(train) [ 350/1200]  lr: 1.6639e-04  eta: 3:18:54  time: 13.9209  data_time: 0.0301  memory: 14169  loss: 0.0176
05/09 12:46:40 - mmengine - INFO - Iter(train) [ 360/1200]  lr: 1.6435e-04  eta: 3:16:53  time: 14.8487  data_time: 0.0294  memory: 14172  loss: 0.0503
05/09 12:49:04 - mmengine - INFO - Iter(train) [ 370/1200]  lr: 1.6226e-04  eta: 3:14:39  time: 14.3601  data_time: 0.0289  memory: 14170  loss: 0.0551
05/09 12:51:29 - mmengine - INFO - Iter(train) [ 380/1200]  lr: 1.6012e-04  eta: 3:12:28  time: 14.5297  data_time: 0.0293  memory: 14170  loss: 0.0452
05/09 12:53:50 - mmengine - INFO - Iter(train) [ 390/1200]  lr: 1.5795e-04  eta: 3:10:08  time: 14.1204  data_time: 0.0276  memory: 14169  loss: 0.0306
05/09 12:56:10 - mmengine - INFO - Iter(train) [ 400/1200]  lr: 1.5573e-04  eta: 3:07:46  time: 14.0397  data_time: 0.0302  memory: 14169  loss: 0.0028
05/09 12:58:31 - mmengine - INFO - Iter(train) [ 410/1200]  lr: 1.5346e-04  eta: 3:05:24  time: 14.0401  data_time: 0.0541  memory: 14172  loss: 0.1111
05/09 13:00:56 - mmengine - INFO - Iter(train) [ 420/1200]  lr: 1.5116e-04  eta: 3:03:12  time: 14.5506  data_time: 0.0393  memory: 14170  loss: 0.0110
05/09 13:03:21 - mmengine - INFO - Iter(train) [ 430/1200]  lr: 1.4883e-04  eta: 3:00:58  time: 14.4790  data_time: 0.0370  memory: 14170  loss: 0.8442
05/09 13:05:41 - mmengine - INFO - Iter(train) [ 440/1200]  lr: 1.4645e-04  eta: 2:58:35  time: 13.9606  data_time: 0.0434  memory: 14170  loss: 0.0277
05/09 13:08:01 - mmengine - INFO - Iter(train) [ 450/1200]  lr: 1.4405e-04  eta: 2:56:12  time: 14.0194  data_time: 0.0314  memory: 14168  loss: 0.0015
05/09 13:10:24 - mmengine - INFO - Iter(train) [ 460/1200]  lr: 1.4161e-04  eta: 2:53:55  time: 14.3201  data_time: 0.0301  memory: 14172  loss: 0.0922
05/09 13:12:47 - mmengine - INFO - Iter(train) [ 470/1200]  lr: 1.3914e-04  eta: 2:51:36  time: 14.2605  data_time: 0.0315  memory: 14170  loss: 0.0422
05/09 13:15:10 - mmengine - INFO - Iter(train) [ 480/1200]  lr: 1.3664e-04  eta: 2:49:18  time: 14.2796  data_time: 0.0325  memory: 14170  loss: 0.0175
05/09 13:17:34 - mmengine - INFO - Iter(train) [ 490/1200]  lr: 1.3412e-04  eta: 2:47:02  time: 14.4799  data_time: 0.0321  memory: 14170  loss: 0.0168
05/09 13:20:01 - mmengine - INFO - Iter(train) [ 500/1200]  lr: 1.3157e-04  eta: 2:44:48  time: 14.6202  data_time: 0.0311  memory: 14169  loss: 0.0106
05/09 13:20:01 - mmengine - INFO - after_train_iter in EvaluateChatHook.
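
While training runs, you can watch actual GPU usage from a second terminal; the number should roughly match the memory field in the log above:

# refresh GPU stats every 5 seconds
watch -n 5 nvidia-smi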

After training completes, convert both the original image projector and the fine-tuned image projector to HuggingFace format, in preparation for trying out the results below. First, the original image projector:

xtuner convert pth_to_hf ~/Llama3-Tutorial/configs/llama3-llava/llava_llama3_8b_instruct_qlora_clip_vit_large_p14_336_lora_e1_finetune.py \
  ~/model/llama3-llava-iter_2181.pth \
  ~/llama3_llava_pth/pretrain_iter_2181_hf

This converts the original image projector to HuggingFace format. Next, convert the fine-tuned image projector:

xtuner convert pth_to_hf ~/Llama3-Tutorial/configs/llama3-llava/llava_llama3_8b_instruct_qlora_clip_vit_large_p14_336_lora_e1_finetune.py \
  ~/llama3_llava_pth/iter_1200.pth \
  ~/llama3_llava_pth/iter_1200_hf
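
Optionally, confirm that both converted projector directories exist before moving on to the chat step:

ls ~/llama3_llava_pth/pretrain_iter_2181_hf ~/llama3_llava_pth/iter_1200_hf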

Once the conversion is complete, we can try out the model from the command line with two test questions:

Question 1: Describe this image.
Question 2: What is the equipment in the image?

Pretrain model

export MKL_SERVICE_FORCE_INTEL=1
xtuner chat /root/model/Meta-Llama-3-8B-Instruct \
  --visual-encoder /root/model/clip-vit-large-patch14-336 \
  --llava /root/llama3_llava_pth/pretrain_iter_2181_hf \
  --prompt-template llama3_chat \
  --image /root/tutorial/xtuner/llava/llava_data/test_img/oph.jpg

As the output shows, the pretrain model only produces caption-style tags for the image; it cannot actually answer questions.
Fine-tuned model

export MKL_SERVICE_FORCE_INTEL=1
xtuner chat /root/model/Meta-Llama-3-8B-Instruct \
  --visual-encoder /root/model/clip-vit-large-patch14-336 \
  --llava /root/llama3_llava_pth/iter_1200_hf \
  --prompt-template llama3_chat \
  --image /root/tutorial/xtuner/llava/llava_data/test_img/oph.jpg

After fine-tuning, the model can answer our questions based on the image. This experiment follows the Llama3-Tutorial; interested readers can visit that repository to learn more.
