本文是LLM系列文章,针对《JORA: JAX Tensor-Parallel LoRA Library for Retrieval Augmented
Fine
摘要
用于基于检索的任务的大型语言模型(LLM)的扩展,特别是在检索增强生成(RAG)中,面临着巨大的内存限制,尤其是在微调大量提示序列时。当前的开源库支持跨多个GPU的全模型推理和微调,但无法适应检索上下文所需的高效参数分布。为了弥补这一差距,我们引入了一种新的框架,利用分布式训练对Llama-2模型进行PEFT兼容的微调。我们的框架独特地利用了JAX的实时(JIT)编译和张量分片来实现高效的资源管理,从而加速了微调,降低了内存需求。这一进步显著提高了为复杂RAG应用程序微调LLM的可扩展性和可行性,即使在GPU资源有限的系统上也是如此。我们的实验表明,与使用四个GPU的Hugging Face/DeepSpeed实现相比,运行时间提高了12倍以上,而每个GPU消耗的VRAM不到一半。