从零训练一个多模态LLM:预训练+指令微调+对齐+融合多模态+链接外部系统

图片

本文尝试梳理一个完整的多模态LLM的训练流程。包括模型结构选择、数据预处理、模型预训练、指令微调、对齐、融合多模态以及链接外部系统等环节。

01

图片

准备阶段

1 模型结构

目前主要有三种模型架构,基于Transformer解码器,基于General Language Model,以及混合专家模型。这一步可以直接选择开源的的基座模型,例如基于Transformer解码器架构的LLaMA模型族,模型结构及一些重要参数如下图。假设选择LLaMA-65B,Tokenizer选择LLaMA的基于BPE算法构造的tokenizer。如果想要扩展词表,可以在目标语言上训练好词表后和LLaMA的词表merge在一起。

图片

02

图片

预训练数据

1 数据源

根据Chinchilla 的scaling law,要达到最优的计算利用率,65B模型对应的训练token数量应该达到1.4T。当前用于训练LLM的数据来源很多,但其中的高质量数据有限,该数据是提升模型性能的关键。另外,有文章指出,代码数据有助于提升模型的推理能力。因此,需要混合多种数据来源的数据,并合理分配每周数据的占比。如下图,可以参考LLaMA的数据源和比例,其中Disk size表示可用的数据总量,Sampling prop表示在总训练token中的占比,epochs表示采样的次数。

图片
2 数据处理

想要提升模型的性能,除了利用已有的开源数据集,例如The Pile,C4,OSCAR等,还可以自己构建数据集。在论文RefinedWeb 中提到,从网页数据中创建的数据集也可以达到和精心收集的数据集同等的效果。Wikipedia,Books,GitHub,ArXiv,以及StackExchange等高质量数据的处理方法可以参考论文The Pile 。下面介绍从Common Crawl构建数据集的方法,CC是一个海量的、非结构化的、多语言的网页数据集,拥有超过8年的网络爬虫数据集,包括原始网页数据(WARC)、元数据(WAT)和文本提取(WET)。每个月都会发布包括一个 20~30 TB 未压缩纯文本的快照,包含了随机搜索和采样的 URL 所获得的网页,不同月份发布的数据之间只有非常少量的数据重合,8 年以来所有爬下来的数据总和是PB级别的。如下图,数据处理流程包含文档准备,过滤以及去重三个步骤,做法参考RefinedWeb论文。

图片
2.1 文档准备

包括数据读取、过滤url、提取文本和语言识别;

数据读取

文本数据既可以从WET文件也可以从WARC文件中读取。直接使用WET文件可以省略从HTML文件中提取文本的工作,但是包含一些不相关信息。因此可以从WARC文件中读取文本。

过滤URL

在正式处理文本数据之前,首先要对URL执行第一次过滤,过滤的目标是欺诈和成人网站(主要是色情、暴力、与赌博有关的网站等)。基于两个规则进行过滤:(1)一个包含460万个域名的屏蔽列表;(2) URL评分,基于收集到的特定单词列表,并按严重程度进行权衡。同时,可以按照需要过滤掉包含在高文本数据集中的数据来源,例如Wikipedia和arXiv等。

提取文本

目的是提取HTML页面中的主要内容,忽略菜单、页眉、页脚和广告等。可以采用trafilatura,jusText等库,结合正则表达式进行文本提取。最终将新行限制为连续的两行,并删除所有URL链接。

语言识别

语言识别可以在去重之前也可以在去重之后进行。但当文档数量比较少的时候,先识别会导致部分语言分类错误。可以采用fastText语言分类器进行语言分类,该分类器是在Wikipedia、Tatoeba和SETimes上面训练的,使用n-grams来作为特征,并采用层级softmax,支持 176 种语言的分类,最后输出一个 0~1 的分数。删除最高语言分数低于设定阈值的文档。通过改变阈值,可以调整保留的文档比例。

2.2 过滤

从网页提取的文档质量低下,过滤的目的是移除重复段落,无关内容,非自然语言等等,提高文本质量。包括文档级别和行级过滤;

包含重复的文档移除

可以在去重阶段进行,但在早期进行代价更低,也更容易。一般采用启发式方法,制定一系列规则删除任何具有过多行、段落或n-gram重复的文档,做法可以参考论文BLOOM 。

文档过滤

主要的目的是保留人类写给人类的自然语言文档,移除机器生成的垃圾邮件,主要由关键字列表、样板文本或特殊字符序列组成。这样的文档不适合语言建模。采用质量过滤启发式算法,做法可以参考论文BLOOM。重点是根据文档长度、符号与单词的比率和其他标准方面去除异常值,以确保文档是由真正的自然语言构成。

行级过滤

通过trafilatura库提取的文本避免了大部分无关的内容,但仍然有遗漏。通过一个线性校正过滤器继续过滤和正文无关的内容(例如点赞数,导航按钮等)。

2.3 去重

过滤之后,数据质量得到了提高,但很多文档是重复的。可以通过模糊文档匹配和精确序列删除对文档进行去重。

模糊去重

可以采用SimHash,MinHash算法删除相似的文档:对于每个文档,计算其与其他文档的近似相似性,并删除高重叠的文档对。通过更改哈希算法的参数,可以调整去重的比例。

精确去重

一般采用精确子字符串去重,是序列级去重。通过使用后缀数组查找字符串之间的精确匹配,删除重复超过给定阈值的连续token的段落。

URL去重

进一步删除跨CC转储重复访问的URL。

03

图片

模型预训练

基于Transformer解码器架构的LM的预训练的方法是让模型做 Next Token Prediction 任务。基于GLM的LM的预训练方法是让模型做自回归空白填充任务。LLM由于规模大,权重维度高,参数量以及数据量多,因此会带来训练不稳定,难以收敛,耗时长,计算资源庞大等问题。下面从模型结构和训练技巧方面介绍一些提升模型的训练速度以及提高训练稳定性的方法。

1 结构改进

采用Pre-normalization, 例如RMSNorm,在残差连接前对参数归一化,有一部分参数直接与后面的参数相加,可以防止梯度爆炸或消失。用 SwiGLU 激活替代 ReLU,提高模型性能。相较于常规的绝对位置编码,相对位置编码在长序列上性能更好,例如RoPE和ALiBi编码方法。常规的自注意力查询需要大量的计算和存储资源,通过采用 Multi Query Attention 和 Flash Attention 可以减少计算,提升训练速度。

<think>好的,用户现在想了解如何从开始编写多模态大模型的代码实现教程。我需要结合之前的引用内容,特别是引用[2]、[3]、[4]中的信息来构建回答。首先,用户可能对多模态大模型的基本结构不太清楚,所以需要先解释核心概念,比如特征对齐、跨模态交互等。然后,根据引用[2]提到的从构建LLM,可能需要分步骤讲解,比如数据准备、模型架构、预训练微调等阶段。 接下来,引用[3]提到了LLaVA的训练阶段,包括特征对齐和视觉指令微调,这可以作为代码实现的具体步骤。例如,特征对齐阶段需要训练MLP连接器,这可能涉及到将图像和文本特征映射到同一空间。需要展示如何用PyTorch实现这样的连接层,并给出示例代码。 另外,引用[4]中的RAG+多模态模型实战可能涉及到API密钥的设置,但用户的问题更偏向于代码实现,所以可能需要忽略这部分,专注于模型架构和训练流程。同时,要确保代码示例符合系统指令,比如数学公式用$...$格式,独立公式用$$...$$,并检查LaTeX语法是否正确。 用户还可能关心如何整合不同模态的数据,比如文本和图像的处理方法。需要解释如何处理图像特征(如使用ResNet提取),文本特征(如使用BERT编码),然后通过连接层进行对齐。此外,训练策略如两阶段训练预训练微调)也是重点,需要详细说明每个阶段的数据集和损失函数。 可能还需要考虑多模态交互机制,比如交叉注意力或融合层,这部分可以引用Transformer的结构,如引用[2]中的Self-Attention实现,展示如何将不同模态的特征进行交互。代码示例应包括定义模型类、前向传播方法,以及训练循环的基本结构。 最后,生成相关问题时要围绕用户可能的后续疑问,比如优化方法、部署应用、数据处理技巧等,确保问题覆盖全面,帮助用户进一步深入学习。同时,确保回答结构清晰,分步骤讲解,代码正确,符合用户的需求。</think>### 从构建多模态大模型的代码实现教程 #### 1. 核心架构设计 多模态大模型的核心是**跨模态特征对齐融合**。典型结构包含: - **图像编码器**:使用预训练的ResNet/ViT提取视觉特征 - **文本编码器**:使用BERT/GPT处理文本特征 - **跨模态交互层**:通过注意力机制实现特征融合 - **任务适配层**:输出预测结果 数学表达式描述特征对齐过程: $$ h_v = W_v \cdot \text{ResNet}(I),\quad h_t = W_t \cdot \text{BERT}(T) $$ 其中$W_v, W_t$是可训练的映射矩阵[^3] #### 2. 代码实现步骤(基于PyTorch) **步骤1:定义跨模态连接层** ```python import torch import torch.nn as nn class MultimodalProjector(nn.Module): def __init__(self, vis_dim=2048, txt_dim=768, hidden_dim=512): super().__init__() self.image_proj = nn.Sequential( nn.Linear(vis_dim, hidden_dim), nn.GELU(), nn.LayerNorm(hidden_dim) ) self.text_proj = nn.Sequential( nn.Linear(txt_dim, hidden_dim), nn.GELU(), nn.LayerNorm(hidden_dim) ) def forward(self, image_features, text_features): return self.image_proj(image_features), self.text_proj(text_features) ``` **步骤2:构建跨模态注意力层** ```python class CrossAttention(nn.Module): def __init__(self, dim=512, heads=8): super().__init__() self.mha = nn.MultiheadAttention(dim, heads) self.norm = nn.LayerNorm(dim) def forward(self, query, key_value): attn_out, _ = self.mha(query, key_value, key_value) return self.norm(query + attn_out) ``` **步骤3:整合完整模型** ```python class MultimodalModel(nn.Module): def __init__(self): super().__init__() self.vision_encoder = torch.hub.load('pytorch/vision', 'resnet50', pretrained=True) self.text_encoder = torch.hub.load('huggingface/pytorch-transformers', 'model', 'bert-base-uncased') self.projector = MultimodalProjector() self.cross_attn = CrossAttention() self.classifier = nn.Linear(512, num_classes) def forward(self, images, texts): # 特征提取 vis_features = self.vision_encoder(images) txt_features = self.text_encoder(texts).last_hidden_state.mean(1) # 特征对齐 proj_vis, proj_txt = self.projector(vis_features, txt_features) # 跨模态交互 fused_features = self.cross_attn(proj_vis.unsqueeze(0), proj_txt.unsqueeze(0)) # 任务预测 return self.classifier(fused_features.squeeze(0)) ``` #### 3. 训练策略实现 参考LLaVA的两阶段训练法[^3]: ```python # 第一阶段:特征对齐预训练 def pretrain_phase(): optimizer = torch.optim.AdamW(model.projector.parameters(), lr=1e-4) loss_fn = nn.CosineEmbeddingLoss() for batch in pretrain_loader: images, texts = batch vis_feat = model.vision_encoder(images) txt_feat = model.text_encoder(texts) proj_vis, proj_txt = model.projector(vis_feat, txt_feat) loss = loss_fn(proj_vis, proj_txt, torch.ones(batch_size)) optimizer.zero_grad() loss.backward() optimizer.step() # 第二阶段:指令微调 def finetune_phase(): optimizer = torch.optim.Adam(model.parameters(), lr=2e-5) loss_fn = nn.CrossEntropyLoss() for batch in finetune_loader: images, texts, labels = batch outputs = model(images, texts) loss = loss_fn(outputs, labels) optimizer.zero_grad() loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) optimizer.step() ``` #### 4. 关键数学原理 跨模态注意力机制的核心公式: $$ \text{Attention}(Q,K,V) = \text{softmax}(\frac{QK^T}{\sqrt{d_k}})V $$ 其中$Q$来自一个模态,$K,V$来自另一模态,$d_k$是维度缩放因子[^2]
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值