Source path: https://github.com/haotian-liu/LLaVA/blob/main/llava/train/train.py
Disclaimer: these are personal study notes; there may be misunderstandings, and corrections are welcome.
Function definition
def train(attn_implementation=None):
    global local_rank
attn_implementation: the attention backend to use. This argument decides which attention implementation the model is built with (e.g. eager, SDPA, or FlashAttention-2) and is passed to the model constructors later on.
local_rank: used in distributed training; it is the local rank (index) of the current training process on its node. In multi-GPU training each GPU is driven by one process, and local_rank identifies which GPU that process owns. A sketch of the usual pattern is shown below.
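For context (this is not part of train.py), here is a minimal sketch of how local_rank is commonly obtained and used in a PyTorch distributed launch; the LOCAL_RANK / WORLD_SIZE environment variables are the ones torchrun sets, and everything else is illustrative.

import os
import torch
import torch.distributed as dist

def setup_distributed():
    # torchrun sets LOCAL_RANK for each process it spawns on a node
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    if torch.cuda.is_available():
        # bind this process to its own GPU
        torch.cuda.set_device(local_rank)
    if not dist.is_initialized() and "WORLD_SIZE" in os.environ:
        # NCCL is the usual backend for multi-GPU training
        dist.init_process_group(backend="nccl")
    return local_rank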
Argument parsing
parser = transformers.HfArgumentParser(
    (ModelArguments, DataArguments, TrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
local_rank = training_args.local_rank
compute_dtype = (torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))
1. HfArgumentParser is Hugging Face's helper class for parsing command-line arguments directly into dataclasses; parse_args_into_dataclasses() returns one populated instance per dataclass (ModelArguments, DataArguments, TrainingArguments).
2. local_rank and compute_dtype are then set from the parsed arguments: compute_dtype is float16 if --fp16 is passed, bfloat16 if --bf16 is passed, and float32 otherwise. A minimal sketch of this parsing pattern follows.
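A self-contained sketch of the HfArgumentParser pattern, using made-up dataclass fields rather than LLaVA's real ModelArguments/DataArguments/TrainingArguments:

from dataclasses import dataclass, field
from typing import Optional
import transformers

@dataclass
class ToyModelArguments:
    model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
    freeze_backbone: bool = field(default=False)

@dataclass
class ToyDataArguments:
    data_path: Optional[str] = field(default=None)

if __name__ == "__main__":
    # Each dataclass field becomes a --flag on the command line
    parser = transformers.HfArgumentParser((ToyModelArguments, ToyDataArguments))
    model_args, data_args = parser.parse_args_into_dataclasses()
    print(model_args.model_name_or_path, data_args.data_path)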
Model loading and quantization parameter configuration
bnb_model_from_pretrained_args = {}
if training_args.bits in [4, 8]:
    from transformers import BitsAndBytesConfig
    bnb_model_from_pretrained_args.update(dict(
        device_map={"": training_args.device},
        load_in_4bit=training_args.bits == 4,
        load_in_8bit=training_args.bits == 8,
        quantization_config=BitsAndBytesConfig(
            load_in_4bit=training_args.bits == 4,
            load_in_8bit=training_args.bits == 8,
            llm_int8_skip_modules=["mm_projector"],
            llm_int8_threshold=6.0,
            llm_int8_has_fp16_weight=False,
            bnb_4bit_compute_dtype=compute_dtype,
            bnb_4bit_use_double_quant=training_args.double_quant,
            bnb_4bit_quant_type=training_args.quant_type  # {'fp4', 'nf4'}
        )
    ))
When training_args.bits is 4 or 8, this builds the dictionary bnb_model_from_pretrained_args with the bitsandbytes-related loading options (device placement plus a BitsAndBytesConfig). These keyword arguments are later splatted into from_pretrained so the base LLM is loaded in 4-bit or 8-bit precision, which sharply reduces GPU memory use; note that the multimodal projector (mm_projector) is excluded from int8 quantization via llm_int8_skip_modules. A hedged example of the same loading pattern on a generic model is sketched below.
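For illustration only: the same quantized-loading pattern applied to a small public checkpoint (the model name is just an example, and nf4 plus double quantization are commonly used values):

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,   # matmuls run in bf16
    bnb_4bit_use_double_quant=True,          # quantize the quantization constants too
    bnb_4bit_quant_type="nf4",               # 'nf4' or 'fp4'
)

model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m",                     # example checkpoint, not LLaVA's
    quantization_config=quant_config,
    device_map={"": 0},                      # place the whole model on GPU 0
)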
Defining the model: obtaining model
if model_args.vision_tower is not None:
    if 'mpt' in model_args.model_name_or_path:
        config = transformers.AutoConfig.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
        config.attn_config['attn_impl'] = training_args.mpt_attn_impl
        model = LlavaMptForCausalLM.from_pretrained(
            model_args.model_name_or_path,
            config=config,
            cache_dir=training_args.cache_dir,
            **bnb_model_from_pretrained_args
        )
    else:
        model = LlavaLlamaForCausalLM.from_pretrained(
            model_args.model_name_or_path,
            cache_dir=training_args.cache_dir,
            attn_implementation=attn_implementation,
            torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
            **bnb_model_from_pretrained_args
        )
else:
    model = transformers.LlamaForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        attn_implementation=attn_implementation,
        torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
        **bnb_model_from_pretrained_args
    )
model.config.use_cache = False
The model class is chosen by whether a vision tower is configured: with a vision tower, an MPT-based LlavaMptForCausalLM or a LLaMA-based LlavaLlamaForCausalLM is instantiated; without one, a plain transformers.LlamaForCausalLM is used. This makes the train() function quite reusable, since the same code path can train several model variants. use_cache is disabled because the key/value cache only helps generation, not training.
Freezing the backbone
if model_args.freeze_backbone:
    model.model.requires_grad_(False)
This freezes the backbone (the base language model, model.model) so that only newly added layers are updated, which cuts compute and helps prevent overfitting: weights that are already well pretrained do not need to be touched. A quick way to verify what is still trainable is sketched below.
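Not part of train.py, but a handy sanity check after freezing (PEFT-wrapped models offer a similar print_trainable_parameters method):

def count_parameters(model):
    # Summarize how many parameters will actually receive gradients
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    print(f"trainable: {trainable:,} / {total:,} "
          f"({100.0 * trainable / max(total, 1):.2f}%)")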
Further configuration for quantized training
if training_args.bits in [4, 8]:
    from peft import prepare_model_for_kbit_training
    model.config.torch_dtype = (torch.float32 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))
    model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=training_args.gradient_checkpointing)
PEFT (Parameter-Efficient Fine-Tuning) is Hugging Face's library for fine-tuning large models by training only a small number of extra parameters.
torch_dtype: the dtype recorded in the model config.
prepare_model_for_kbit_training gets the quantized model ready for low-precision training, roughly by casting small, numerically sensitive parameters (e.g. layer norms) to float32, making the inputs produce gradients, and optionally enabling gradient checkpointing. The sketch after this note shows an approximate manual equivalent.
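Assuming the behavior described above (this is my approximation of what prepare_model_for_kbit_training does for a Hugging Face model, not its actual source):

import torch

def prepare_for_kbit_training_sketch(model, use_gradient_checkpointing=True):
    # Freeze every base weight; LoRA adapters added later stay trainable
    for param in model.parameters():
        param.requires_grad = False

    # Keep small, numerically sensitive weights (e.g. layer norms) in fp32
    for name, param in model.named_parameters():
        if param.ndim == 1:
            param.data = param.data.to(torch.float32)

    if use_gradient_checkpointing:
        # Trade compute for memory by recomputing activations in the backward pass
        model.gradient_checkpointing_enable()
        # Inputs must carry grads so checkpointed blocks have something to backprop into
        model.enable_input_require_grads()

    return model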
Enabling gradient computation on the inputs
if training_args.gradient_checkpointing:
    if hasattr(model, "enable_input_require_grads"):
        model.enable_input_require_grads()
    else:
        def make_inputs_require_grad(module, input, output):
            output.requires_grad_(True)
        model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
register_forward_hook is a PyTorch method that registers a hook function to be called every time the module finishes a forward pass. Here the hook make_inputs_require_grad fires whenever the input embedding layer runs and sets requires_grad_(True) on its output, so gradients can flow back through the checkpointed transformer blocks even when the embedding weights themselves are frozen. A standalone example of a forward hook is shown below.
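A self-contained illustration of register_forward_hook (a toy module, nothing LLaVA-specific):

import torch
import torch.nn as nn

embedding = nn.Embedding(num_embeddings=10, embedding_dim=4)
embedding.weight.requires_grad_(False)          # pretend the embeddings are frozen

def make_inputs_require_grad(module, input, output):
    # Called after every forward pass of `embedding`
    output.requires_grad_(True)

handle = embedding.register_forward_hook(make_inputs_require_grad)

tokens = torch.tensor([1, 2, 3])
out = embedding(tokens)
print(out.requires_grad)                        # True, thanks to the hook
handle.remove()                                 # hooks can be detached when no longer needed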
LoRA configuration and model conversion
LoRA (Low-Rank Adaptation) is a technique for parameter-efficient fine-tuning of pretrained models: instead of updating the full weight matrices, it learns small low-rank update matrices, which is especially attractive for very large models.
if training_args.lora_enable:
    from peft import LoraConfig, get_peft_model
    lora_config = LoraConfig(
        r=training_args.lora_r,
        lora_alpha=training_args.lora_alpha,
        target_modules=find_all_linear_names(model),
        lora_dropout=training_args.lora_dropout,
        bias=training_args.lora_bias,
        task_type="CAUSAL_LM",
    )
    if training_args.bits == 16:
        if training_args.bf16:
            model.to(torch.bfloat16)
        if training_args.fp16:
            model.to(torch.float16)
    rank0_print("Adding LoRA adapters...")
    model = get_peft_model(model, lora_config)
r: the rank of the LoRA adapters (the inner dimension of the low-rank update).
lora_alpha: a scaling factor; the update is scaled by alpha / r, so it controls how strongly the adapters influence the frozen weights.
target_modules: the modules to attach adapters to; find_all_linear_names(model) collects the linear layers of the language model.
rank0_print: prints the message only on the main process (rank 0), so multi-GPU runs do not log it once per GPU.
get_peft_model: wraps the model with the LoRA adapters, freezing the original weights and leaving only the adapter parameters trainable. A minimal sketch of how r and lora_alpha enter the computation follows this list.
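A toy LoRA linear layer (not PEFT's implementation, just the idea) showing where r and lora_alpha appear:

import torch
import torch.nn as nn

class ToyLoRALinear(nn.Module):
    def __init__(self, base: nn.Linear, r: int = 8, lora_alpha: int = 16):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad_(False)              # the pretrained weight stays frozen
        # Low-rank factors: only these r * (in + out) parameters are trained
        self.lora_A = nn.Parameter(torch.randn(r, base.in_features) * 0.01)
        self.lora_B = nn.Parameter(torch.zeros(base.out_features, r))
        self.scaling = lora_alpha / r            # alpha/r scales the update

    def forward(self, x):
        # y = W x + (alpha/r) * B A x
        return self.base(x) + self.scaling * (x @ self.lora_A.T @ self.lora_B.T)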
Defining the tokenizer
if 'mpt' in model_args.model_name_or_path:
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        model_max_length=training_args.model_max_length,
        padding_side="right"
    )
else:
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        model_max_length=training_args.model_max_length,
        padding_side="right",
        use_fast=False,
    )
mpt: refers to MosaicML's MPT (MosaicML Pretrained Transformer) family of decoder-only language models, one of the base LLMs LLaVA can be built on (it is not itself a multimodal model).
use_fast: whether to use the Rust-based fast tokenizer. Fast tokenizers are more efficient but do not behave identically for every model, so the non-MPT branch (LLaMA-style models) falls back to the slower Python tokenizer with use_fast=False.
Configuring the tokenizer
if model_args.version == "v0":
    if tokenizer.pad_token is None:
        smart_tokenizer_and_embedding_resize(
            special_tokens_dict=dict(pad_token="[PAD]"),
            tokenizer=tokenizer,
            model=model,
        )
elif model_args.version == "v0.5":
    tokenizer.pad_token = tokenizer.unk_token
else:
    tokenizer.pad_token = tokenizer.unk_token
    if model_args.version in conversation_lib.conv_templates:
        conversation_lib.default_conversation = conversation_lib.conv_templates[model_args.version]
    else:
        conversation_lib.default_conversation = conversation_lib.conv_templates["vicuna_v1"]
Initializing vision_tower and the vision tokenizer
if model_args.vision_tower is not None:
    model.get_model().initialize_vision_modules(
        model_args=model_args,
        fsdp=training_args.fsdp
    )
    vision_tower = model.get_vision_tower()
    vision_tower.to(dtype=torch.bfloat16 if training_args.bf16 else torch.float16, device=training_args.device)

    data_args.image_processor = vision_tower.image_processor
    data_args.is_multimodal = True

    model.config.image_aspect_ratio = data_args.image_aspect_ratio
    model.config.tokenizer_padding_side = tokenizer.padding_side
    model.config.tokenizer_model_max_length = tokenizer.model_max_length

    model.config.tune_mm_mlp_adapter = training_args.tune_mm_mlp_adapter = model_args.tune_mm_mlp_adapter
    if model_args.tune_mm_mlp_adapter:
        model.requires_grad_(False)
        for p in model.get_model().mm_projector.parameters():
            p.requires_grad = True

    model.config.freeze_mm_mlp_adapter = training_args.freeze_mm_mlp_adapter
    if training_args.freeze_mm_mlp_adapter:
        for p in model.get_model().mm_projector.parameters():
            p.requires_grad = False

    if training_args.bits in [4, 8]:
        model.get_model().mm_projector.to(dtype=compute_dtype, device=training_args.device)

    model.config.mm_use_im_start_end = data_args.mm_use_im_start_end = model_args.mm_use_im_start_end
    model.config.mm_projector_lr = training_args.mm_projector_lr
    training_args.use_im_start_end = model_args.mm_use_im_start_end
    model.config.mm_use_im_patch_token = model_args.mm_use_im_patch_token
    model.initialize_vision_tokenizer(model_args, tokenizer=tokenizer)
It looks like a lot of plumbing, but the purpose is simply to set up the vision tower (by default a CLIP vision encoder) and the vision tokenizer inside the model: initialize_vision_modules builds the vision tower and mm_projector, the image processor is handed to data_args for preprocessing, and depending on tune_mm_mlp_adapter / freeze_mm_mlp_adapter the mm_projector is made trainable or frozen. initialize_vision_tokenizer then adds the image-related special tokens to the tokenizer.
Converting the data type of certain layers in the model
if training_args.bits in [4, 8]:
    from peft.tuners.lora import LoraLayer
    for name, module in model.named_modules():
        if isinstance(module, LoraLayer):
            if training_args.bf16:
                module = module.to(torch.bfloat16)
        if 'norm' in name:
            module = module.to(torch.float32)
        if 'lm_head' in name or 'embed_tokens' in name:
            if hasattr(module, 'weight'):
                if training_args.bf16 and module.weight.dtype == torch.float32:
                    module = module.to(torch.bfloat16)
Training and saving the model
data_module = make_supervised_data_module(tokenizer=tokenizer,
                                          data_args=data_args)
trainer = LLaVATrainer(model=model,
                       tokenizer=tokenizer,
                       args=training_args,
                       **data_module)

if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
    trainer.train(resume_from_checkpoint=True)
else:
    trainer.train()
trainer.save_state()

model.config.use_cache = True

if training_args.lora_enable:
    state_dict = get_peft_state_maybe_zero_3(
        model.named_parameters(), training_args.lora_bias
    )
    non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3(
        model.named_parameters()
    )
    if training_args.local_rank == 0 or training_args.local_rank == -1:
        model.config.save_pretrained(training_args.output_dir)
        model.save_pretrained(training_args.output_dir, state_dict=state_dict)
        torch.save(non_lora_state_dict, os.path.join(training_args.output_dir, 'non_lora_trainables.bin'))
else:
    safe_save_model_for_hf_trainer(trainer=trainer,
                                   output_dir=training_args.output_dir)
If the output directory already contains checkpoint-* folders, training resumes from the latest one. trainer.save_state() writes the Trainer's bookkeeping state (trainer_state.json with the log history and global step); it is not the same thing as saving the model weights. With LoRA enabled, only the main process (local_rank 0 or -1) saves: the adapter weights go through model.save_pretrained, while the remaining trainable non-LoRA weights (e.g. the mm_projector) are written to non_lora_trainables.bin; without LoRA, the full model is saved via safe_save_model_for_hf_trainer. A hedged sketch of how these artifacts might be reloaded follows.
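A sketch of how the two saved pieces could be reloaded for inference. This is not LLaVA's actual loading code (that lives in the repo's model-loading utilities); the paths and the base-model location are placeholders, and the exact key-prefix handling is simplified.

import os
import torch
from peft import PeftModel
from llava.model import LlavaLlamaForCausalLM   # LLaVA's model class, assumed importable

output_dir = "./checkpoints/llava-lora"          # placeholder path

# 1) Load the frozen base model first
base_model = LlavaLlamaForCausalLM.from_pretrained("path/to/base-llm",  # placeholder
                                                   torch_dtype=torch.float16)

# 2) Restore the extra trainable weights that are not LoRA adapters (e.g. mm_projector);
#    keys may carry a 'base_model.' prefix from the PEFT wrapping, so strip it
non_lora = torch.load(os.path.join(output_dir, "non_lora_trainables.bin"),
                      map_location="cpu")
non_lora = {k[len("base_model."):] if k.startswith("base_model.") else k: v
            for k, v in non_lora.items()}
base_model.load_state_dict(non_lora, strict=False)

# 3) Attach the LoRA adapters saved by model.save_pretrained, then merge them in
model = PeftModel.from_pretrained(base_model, output_dir)
model = model.merge_and_unload()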
Personal summary:
LoRA is a method for fine-tuning large pretrained models; worth studying in depth when time allows.
My understanding of the dtype-conversion parts is still shallow; I need to learn more about model quantization and distributed training.