[CLIP-VIT-L + Qwen] 多模态大模型源码阅读 - MultiModal篇
参考repo:WatchTower-Liu/VLM-learning; url: VLLM-BASE
前情提要
有关多模态大模型架构中的语言模型部分(MQwen.py)的代码请看(多模态大模型源码阅读 - 1、 多模态大模型源码阅读 - 2, 多模态大模型源码阅读 - 3,多模态大模型源码阅读 - 4),多模态大模型架构中的视觉模型(visual/CLIP-VIT.py)部分请看多模态大模型源码阅读 - 5,多模态大模型架构中的trainer(trainer.py)部分请看多模态大模型源码阅读 - 6。
本节将讲解如何将之前重构的MQwen语言模型部分和CLIP-VIT视觉模型部分整合为MultiModal多模态模型类,并利用多模态模型类进行前向传播,生成预测内容。
源码阅读
导包
import torch
from torch import nn
from typing import Optional
import os
import sys
sys.path.append("../")
from transformers.modeling_outputs import CausalLMOutputWithPast
from dataclasses import dataclass, asdict
from peft import get_peft_model, LoraConfig, TaskType, PeftModel
from visual.CLIP_VIT import visualModel
from qwen.Mqwen import MQWenLMHeadModel
逐行讲解
对于部分已经用了无数次的模块就不再赘述了~
from typing import Optional
typing模块最重要的就是类型注释功能,这里导入的Optional表示变量可以是制定的类型或者None。例如Optional[str]表示变量可以是str类型或者None。
import sys
sys.path.append("../")
将上一层级目录添加到系统路径中,可以将上一层级的模块直接通过模块名导入。例如上一层级目录中定义了一个叫做abc.py的模块,那么就可以通过import abc直接导入。
from transformers.modeling_outputs import CausalLMOutputWithPast
from dataclasses import dataclass, asdict
CausalLMOutputWithPast专门用于封装因果模型的输出,包含了模型输出和过去的隐藏状态。
dataclass装饰器用于封装数据类型,asdict可以将数据类实例转换为字典。
from peft import get_peft_model, LoraConfig, TaskType, PeftModel
peft用于模型微调,get_peft_model方法获取LoRA,prefix tuning等不同类别的微调模型,LoRA包含了LoRA模型的必要配置参数,TaskType定义模型执行的不同任务类型,如文本分类、摘要总结等。PeftModel是一个基类,指定PEFT的配置。
dataclass部分
@dataclass
class LanguageConfig():
model_path: str
torch_dtype: torch.dtype = torch.bfloat16
trust_remote_code: bool = True
@dataclass
class VisualConfig():
model_path: str
pretrained: bool = True
@dataclass
class MultiModalConfig():
replace_token_id: int
# image_context_length: int = 256
image_context_length: int = 728
image_feature_hidden_size: int = 4096
整体含义
用于封装不同配置下的参数类型和初始值。
逐行解读
@dataclass
class LanguageConfig():
model_path: str
torch_dtype: torch.dtype = torch.bfloat16
trust_remote_code: bool = True
LanguageConfig类用于存储和管理语言模型的参数和配置。
model_path代表模型的存储地址,通常为字符串类型,无初始值,需要用户手动传入。
torch_type代表模型使用的数据类型,这里使用半精度浮点数bfloat16
trust_remote_code默认为True,当我们要远程从huggingface加载预训练模型时,通常需要保持这个值为True,因为我们运行的不是本地代码,本地下载模型的可以无视。
@dataclass
class VisualConfig():
model_path: str
pretrained: bool = True
VisualConfig代表视觉模型的参数和配置类,model_path与上文相同。
pretrained代表是否加载预训练模型的权重。
@dataclass
class MultiModalConfig():
replace_token_id: int
# image_context_length: int = 256
image_context_length: int = 728
image_feature_hidden_size: int = 4096
MultiModalConfig代表多模态模型的参数和配置类。
replace_token_id指定input_ids中用于替换的token_id,例如输入为[102,103,101]的数据,指定101为replace_token_id,则将101替换为图片特征数据。
image_context_length代表图像上下文长度。
image_feature_hidden_size指定图像特征隐藏层维度大小。
模型微调
def make_lora(model, finetune_args):
peft_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
inference_mode=False,
r=finetune_args.lora_rank,
lora_alpha=32,
lora_dropout=finetune_args.lora_dropout,
target_modules = finetune_args.target_modules.split('|')