使用fsdp_qlora/README.md at main · AnswerDotAI/fsdp_qlora · GitHub这个框架进行llama2的finetune训练
该框架支持不同的训练模式,使用HQQ(Half-Quadratic Quantization)和Bitsandbytes两种不同的量化技术
Training Options
For quantization we support HQQ and bitsandbytes. We're currently doing benchmarking to help you decide which to use. If you do use bitsandbytes, be sure to pass --reentrant_checkpointing True
to avoid triggering a bug in bitsandbytes which results in high memory usage (a fix is in progress).
在使用"hqq_lora"进行训练时
model = FSDP(
model,
sharding_strategy=sharding_strategy,
auto_wrap_policy=my_auto_wrap_policy,
# backward_prefetch=None, #BackwardPrefetch.BACKWARD_PRE
use_orig_params=False,
cpu_offload=CPUOffload(offload_params=True) if args["use_cpu_offload"] else None,
limit_all_gathers=True, # See https://github.com/pytorch/pytorch/issues/91165
device_id=torch.cuda.current_device(),
sync_module_states=args["low_memory"],
param_init_fn=lambda module: module.to_empty(device=torch.device("cuda"), recurse=False)
if (rank != 0 and args["low_memory"]) else None, # TODO note about meta device and why we need this
mixed_precision=mp_policy,
)
下面这一行一直报错:NotImplementedError: Cannot copy out of meta tensor; no data!
param_init_fn=lambda module: module.to_empty(device=torch.device("cuda"), recurse=False)
使用了
自定义的to_empty
方法---旨在将模块的参数转移到CUDA设备上。但在此过程中,HQQ
模块中的某个量化参数(W_q
)似乎还在元设备上,且尝试直接将其复制到CUDA设备时失败了
合理怀疑是在模型的wrap阶段:
# Wrap model with llama-recipies or custom LoRA policy
my_auto_wrap_policy = get_wrapping_policy(
custom_policy=args["train_type"] in ["custom_qlora", "hqq_lora", "hqq_dora", "bnb_dora"],
vanilla_policy=args["train_type"] in ["full", "bnb_llama_pro", "hqq_llama_pro"])
元设备上的张量是用于构建和测试,而没有实际的数据载体,所以无法直接转移
于是自作主张,将HQQ框架中的to_empty替换为module.nn框架下的
def to_empty(self, device, recurse=True):
# return self.cuda(device) #原来的
return self._apply(lambda t: torch.empty_like(t, device=device), recurse=recurse)
报错就解除了
自己的理解是:
1)原来的方法使用self.cuda(device)
试图将包含在模块中的所有参数和缓冲区转移到指定的CUDA设备上。然而,当模块中包含元设备(meta device)上的张量时,这种直接转移的尝试因为元张量不包含实际数据而失败,导致了NotImplementedError
错误。
2)替换的方法对于模块中的每一个张量t
,这个函数创建一个新的、在指定device
(比如GPU)上的空张量。这个新张量有着与t
相同的形状和类型,但不包含任何数据
太深奥了,仍然半知半解