LLAVA代码阅读:train.py

源码路径:https://github.com/haotian-liu/LLaVA/llava/model/train/train.py

声明:此博客为个人学习笔记,可能有理解错误的地方,敬请指正

函数定义

def train(attn_implementation=None):
    global local_rank

attn_implementation:注意力机制接口,此参数决定模型使用何种注意力机制,此参数会在后面模型定义处使用

local_rank:在分布式训练中使用,表示当前训练实例(进程)的本地排名或编号。在分布式训练时,local_rank 用于指定每个进程的任务(比如在多卡训练时,每张GPU会有一个 local_rank

参数解析

parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    local_rank = training_args.local_rank
    compute_dtype = (torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))

1.使用HuggingFace提供的专门解析命令行参数的类,并返回类实例

2.根据解析后的参数给local_rank和compute_dtype赋值

模型加载与量化参数配置

    bnb_model_from_pretrained_args = {}
    if training_args.bits in [4, 8]:
        from transformers import BitsAndBytesConfig
        bnb_model_from_pretrained_args.update(dict(
            device_map={"": training_args.device},
            load_in_4bit=training_args.bits == 4,
            load_in_8bit=training_args.bits == 8,
            quantization_config=BitsAndBytesConfig(
                load_in_4bit=training_args.bits == 4,
                load_in_8bit=training_args.bits == 8,
                llm_int8_skip_modules=["mm_projector"],
                llm_int8_threshold=6.0,
                llm_int8_has_fp16_weight=False,
                bnb_4bit_compute_dtype=compute_dtype,
                bnb_4bit_use_double_quant=training_args.double_quant,
                bnb_4bit_quant_type=training_args.quant_type # {'fp4', 'nf4'}
            )
        ))

此处根据training_args中的参数构造了字典bnb_model_from_pretrained_args,配置了与bits_and_bytes相关的模型加载参数,特别是量化时加载模型的方式,为的是减少显存占用

定义模型:获得model

    if model_args.vision_tower is not None:
        if 'mpt' in model_args.model_name_or_path:
            config = transformers.AutoConfig.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
            config.attn_config['attn_impl'] = training_args.mpt_attn_impl
            model = LlavaMptForCausalLM.from_pretrained(
                model_args.model_name_or_path,
                config=config,
                cache_dir=training_args.cache_dir,
                **bnb_model_from_pretrained_args
            )
        else:
            model = LlavaLlamaForCausalLM.from_pretrained(
                model_args.model_name_or_path,
                cache_dir=training_args.cache_dir,
                attn_implementation=attn_implementation,
                torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
                **bnb_model_from_pretrained_args
            )
    else:
        model = transformers.LlamaForCausalLM.from_pretrained(
            model_args.model_name_or_path,
            cache_dir=training_args.cache_dir,
            attn_implementation=attn_implementation,
            torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
            **bnb_model_from_pretrained_args
        )
    model.config.use_cache = False

根据是否有视觉塔来定义不同的模型。可见LLAVA工程train函数的复用性较强,可支持多种模型的训练

冻结骨干网络

if model_args.freeze_backbone:
    model.model.requires_grad_(False)

冻结骨干网络,只对新加入的层进行训练更新,有效防止过拟合且减小计算量。很好理解:一些预训练过的东西就没有必要动它了

进一步配置量化训练

    if training_args.bits in [4, 8]:
        from peft import prepare_model_for_kbit_training
        model.config.torch_dtype=(torch.float32 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))
        model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=training_args.gradient_checkpointing)

PEFT(Parameter Efficient Fine-Tuning)通常用于高效微调大模型  

torch_dtype:模型精度

为低精度训练做好准备,并设置梯度检查点

启动部分层的梯度计算

    if training_args.gradient_checkpointing:
        if hasattr(model, "enable_input_require_grads"):
            model.enable_input_require_grads()
        else:
            def make_inputs_require_grad(module, input, output):
                output.requires_grad_(True)
            model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

register_forward_hook :是 PyTorch 中的一个方法,它允许你在模型前向传播时插入一个钩子函数(hook)。该钩子函数会在模型执行前向传播时被调用,make_inputs_require_grad 就是在这个时候被触发的。通过这个钩子函数,每当模型的输入嵌入层(input_embeddings)被调用时,它都会确保其输出的 requires_grad 属性被设置为 True,即启用梯度计算。

LoRA配置和模型转换

LoRA 是一种用于有效微调预训练模型的技术,特别是对于大模型的参数高效调整

 if training_args.lora_enable:
        from peft import LoraConfig, get_peft_model
        lora_config = LoraConfig(
            r=training_args.lora_r,
            lora_alpha=training_args.lora_alpha,
            target_modules=find_all_linear_names(model),
            lora_dropout=training_args.lora_dropout,
            bias=training_args.lora_bias,
            task_type="CAUSAL_LM",
        )
        if training_args.bits == 16:
            if training_args.bf16:
                model.to(torch.bfloat16)
            if training_args.fp16:
                model.to(torch.float16)
        rank0_print("Adding LoRA adapters...")
        model = get_peft_model(model, lora_config)

r:LoRA适配器的秩

lora_alpha:放大因子,决定LORA的影响力

target_moudles:目标模块,这里是一些线性层

rank0_print:主进程打印提示信息

get_peft_model:对模型进行进一步处理

定义Tokenizer

    if 'mpt' in model_args.model_name_or_path:
        tokenizer = transformers.AutoTokenizer.from_pretrained(
            model_args.model_name_or_path,
            cache_dir=training_args.cache_dir,
            model_max_length=training_args.model_max_length,
            padding_side="right"
        )
    else:
        tokenizer = transformers.AutoTokenizer.from_pretrained(
            model_args.model_name_or_path,
            cache_dir=training_args.cache_dir,
            model_max_length=training_args.model_max_length,
            padding_side="right",
            use_fast=False,
        )

mpt:Multi-Modal Pretrained Transformer是一种能够处理多模态输入的 Transformer 模型

use_fast:是否使用fast tokenizer比较高效,但不支持一些情况,所以在第二种情况下使用的是标准Python tokenizer

配置Tokenizer

    if model_args.version == "v0":
        if tokenizer.pad_token is None:
            smart_tokenizer_and_embedding_resize(
                special_tokens_dict=dict(pad_token="[PAD]"),
                tokenizer=tokenizer,
                model=model,
            )
    elif model_args.version == "v0.5":
        tokenizer.pad_token = tokenizer.unk_token
    else:
        tokenizer.pad_token = tokenizer.unk_token
        if model_args.version in conversation_lib.conv_templates:
            conversation_lib.default_conversation = conversation_lib.conv_templates[model_args.version]
        else:
            conversation_lib.default_conversation = conversation_lib.conv_templates["vicuna_v1"]

vision_tower和vision_tokenizer初始化

    if model_args.vision_tower is not None:
        model.get_model().initialize_vision_modules(
            model_args=model_args,
            fsdp=training_args.fsdp
        )

        vision_tower = model.get_vision_tower()
        vision_tower.to(dtype=torch.bfloat16 if training_args.bf16 else torch.float16, device=training_args.device)

        data_args.image_processor = vision_tower.image_processor
        data_args.is_multimodal = True

        model.config.image_aspect_ratio = data_args.image_aspect_ratio
        model.config.tokenizer_padding_side = tokenizer.padding_side
        model.config.tokenizer_model_max_length = tokenizer.model_max_length

        model.config.tune_mm_mlp_adapter = training_args.tune_mm_mlp_adapter = model_args.tune_mm_mlp_adapter
        if model_args.tune_mm_mlp_adapter:
            model.requires_grad_(False)
            for p in model.get_model().mm_projector.parameters():
                p.requires_grad = True

        model.config.freeze_mm_mlp_adapter = training_args.freeze_mm_mlp_adapter
        if training_args.freeze_mm_mlp_adapter:
            for p in model.get_model().mm_projector.parameters():
                p.requires_grad = False

        if training_args.bits in [4, 8]:
            model.get_model().mm_projector.to(dtype=compute_dtype, device=training_args.device)

        model.config.mm_use_im_start_end = data_args.mm_use_im_start_end = model_args.mm_use_im_start_end
        model.config.mm_projector_lr = training_args.mm_projector_lr
        training_args.use_im_start_end = model_args.mm_use_im_start_end
        model.config.mm_use_im_patch_token = model_args.mm_use_im_patch_token
        model.initialize_vision_tokenizer(model_args, tokenizer=tokenizer)

看似一顿操作,其实就是为了配置一个vision_tower和vision_tokenier,这两个都在model里

对模型中的部分层做数据类型转换

    if training_args.bits in [4, 8]:
        from peft.tuners.lora import LoraLayer
        for name, module in model.named_modules():
            if isinstance(module, LoraLayer):
                if training_args.bf16:
                    module = module.to(torch.bfloat16)
            if 'norm' in name:
                module = module.to(torch.float32)
            if 'lm_head' in name or 'embed_tokens' in name:
                if hasattr(module, 'weight'):
                    if training_args.bf16 and module.weight.dtype == torch.float32:
                        module = module.to(torch.bfloat16)

训练并保存模型

data_module = make_supervised_data_module(tokenizer=tokenizer,
                                              data_args=data_args)
    trainer = LLaVATrainer(model=model,
                    tokenizer=tokenizer,
                    args=training_args,
                    **data_module)

    if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
        trainer.train(resume_from_checkpoint=True)
    else:
        trainer.train()
    trainer.save_state()

    model.config.use_cache = True

    if training_args.lora_enable:
        state_dict = get_peft_state_maybe_zero_3(
            model.named_parameters(), training_args.lora_bias
        )
        non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3(
            model.named_parameters()
        )
        if training_args.local_rank == 0 or training_args.local_rank == -1:
            model.config.save_pretrained(training_args.output_dir)
            model.save_pretrained(training_args.output_dir, state_dict=state_dict)
            torch.save(non_lora_state_dict, os.path.join(training_args.output_dir, 'non_lora_trainables.bin'))
    else:
        safe_save_model_for_hf_trainer(trainer=trainer,
                                       output_dir=training_args.output_dir)

这里trainer.save_state()是保存模型当前状态,防止中断后丢失进度,和保存模型并不同

个人总结:

LoRA是一种大规模预训练模型微调的方法,日后有时间深入学习

还有那些数据类型转换的部分理解不够深入,还需要加强学习模型量化和分布式训练的内容

### 解决方案 在解决 `llava.constants`、`llava.conversation`、`llava.model.builder`、`llava.utils` 和 `llava.mm_utils` 的路径缺失问题时,可以按照以下方式操作: #### 1. **确认依赖文件是否存在** 如果这些模块的路径缺失,可能是因为安装过程中未正确克隆完整的 LLaVA或者缺少必要的子模块。可以通过重新克隆官方仓来解决问题[^2]。 ```bash git clone https://github.com/haotian-liu/LLaVA.git cd LLaVA pip install -r requirements.txt ``` 确保所有依赖项已正确安装,并验证是否有遗漏的 Python 文件或目录。 --- #### 2. **手动创建缺失模块** 如果某些模块确实不存在于当前环境中,则可以根据需求自行定义这些模块的内容。以下是各模块的功能概述以及如何实现它们的方法: ##### a. **`llava.constants`** 该模块通常用于存储全局常量变量。例如模型名称、默认参数等。可参考以下代码结构: ```python # llava/constants.py DEFAULT_MODEL_PATH = "liuhaotian/llava-v1.5-7b" IMAGE_TOKEN_INDEX = "<image>" CONTEXT_LENGTH = 2048 ``` --- ##### b. **`llava.conversation`** 此模块负责管理对话历史记录和上下文信息。其核心功能是对输入数据进行预处理以便传递给模型。示例代码如下: ```python # llava/conversation.py class Conversation: def __init__(self): self.history = [] def add_message(self, role, content): """Add message to conversation history.""" self.history.append({"role": role, "content": content}) def get_prompt(self): """Generate prompt from conversation history.""" return "\n".join(f"{msg['role']}: {msg['content']}" for msg in self.history) ``` --- ##### c. **`llava.model.builder`** 这个模块主要用于构建和加载模型实例。它调用了 Hugging Face Transformers 中的函数完成具体工作。下面是一个简单的例子: ```python from transformers import AutoTokenizer, AutoModelForCausalLM # llava/model/builder.py def build_model(model_path, device="cuda"): tokenizer = AutoTokenizer.from_pretrained(model_path) model = AutoModelForCausalLM.from_pretrained( model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True, trust_remote_code=True, device_map=device ) return tokenizer, model ``` --- ##### d. **`llava.utils`** 工具类模块一般提供辅助性的静态方法供其他组件调用。比如日志打印、错误检测等功能都可以放在这里面。举个简单案例: ```python # llava/utils.py import logging logger = logging.getLogger(__name__) def setup_logger(level=logging.INFO): handler = logging.StreamHandler() formatter = logging.Formatter("%(asctime)s %(levelname)-8s %(message)s") handler.setFormatter(formatter) logger.addHandler(handler) logger.setLevel(level) setup_logger() ``` --- ##### e. **`llava.mm_utils`** 多媒体实用程序模块主要涉及图像编码解码过程以及其他跨模态任务支持逻辑。基于引用描述[^4],我们可以编写类似的函数来满足实际用途: ```python import torch.nn as nn # llava/mm_utils.py class MMProjector(nn.Module): def __init__(self, input_dim=1152, output_dim=1536): super().__init__() self.projection_layers = nn.Sequential( nn.Linear(input_dim, output_dim), nn.GELU(), nn.Linear(output_dim, output_dim) ) def forward(self, features): return self.projection_layers(features) def map_image_to_text(vision_output): projector = MMProjector() mapped_tensor = projector(vision_output) return mapped_tensor ``` --- #### 3. **调整配置文件** 除了修复上述模块外,还需要检查项目根目录下的 `config.json` 是否已经更新至最新版本[^3]。特别是当引入新的视觉塔(Vision Tower)如 CLIP-ViT-Large-Patch14-336 后更要注意同步修改相关内容设置。 --- ### 总结 通过以上步骤应该能够有效缓解因路径丢失引发的一系列连锁反应问题。当然,在执行任何更改之前建议备份原始工程以防万一出现问题无法回滚恢复原状。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值