【多模态】27、Vary | 通过扩充图像词汇来提升多模态模型在细粒度感知任务(OCR等)上的效果

本文介绍了旷视提出的Vary方法,用于解决大型视觉-语言模型在特殊任务中视觉词汇不足的问题。Vary通过生成新视觉词汇表并整合新旧词汇表,实现更细粒度的视觉感知。文中阐述了方法的具体实现,包括网络结构、数据引擎等,还展示了其在文档理解等任务上的良好效果,并给出代码流程。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

在这里插入图片描述

论文:Vary: Scaling up the Vision Vocabulary for Large Vision-Language Models

代码:https://github.com/Ucas-HaoranWei/Vary

出处:旷视

时间:2023.12

一、背景

当前流行的大型视觉-语言模型 Large Vision-Language Models (LVLMs) 一般都使用共享的 vision vocabulary,这个词库就是 CLIP,因为 CLIP 是公认的包含了很多图像-语言信息的模型,可以 cover 大多数通用的视觉任务。

但对于一些特殊的任务,需要对视觉信息进行更密集细致的提取,比如需要对 document-level 进行 OCR 或字符的理解的任务,或者非英文的场景,CLIP vision vocabulary 就无法表现的很好了。

mPlug-Owl [49] 和 Qwen-VL 尝试了将 vision vocabulary 网络解冻来解决这个问题,但作者认为,有三个不合理的地方:

  • 其一:这样会覆盖掉之前学习的知识:

    这意味着如果你尝试通过向大型语言模型(如7B大小的模型)添加或更新视觉词汇,可能会导致原有的、模型已经学习的词汇知识被新的信息覆盖。因为语言模型通常是在大量文本数据上训练的,它们已经内化了丰富的语言知识和结构。如果试图将视觉元素的信息强加给这些已经存在的词汇,可能会扰乱模型对这些词汇原有的理解。

  • 其二:vision vocabulary 更新的速度更快,大的 LLM(7B)的更新速度慢:

    在一个相对较大的语言模型上更新视觉词汇,训练效率可能会很低。这是因为大型模型参数众多,训练它们需要大量的计算资源和时间。尤其是当试图整合视觉数据时,这个过程可能会变得更加复杂和低效,因为视觉数据通常比文本数据更为复杂并且维度更高。

  • 其三:不能让视觉词汇网络多次“看到”图像:

    由于大型语言模型(LLMs)具有很强的记忆能力,它们在处理信息时不需要多次“看到”同一个输入。这意味着,一旦模型学习了某个图像的信息,它就能够记住这些信息,而不需要像传统的视觉识别网络那样通过多个训练周期(epochs)多次学习同一个数据集。这种强记忆能力可能会限制模型在学习视觉词汇时的灵活性。

基于此,作者提出了一个问题:是否存在一种策略可以简化并有效增强视觉词汇?

简化并增强视觉词汇的策略可能包括创建更高效的模型架构,使用更先进的训练技术,或者开发新的算法来更好地整合视觉信息和文本信息,而不会受到上述限制的影响。

这篇论文提出了一个名为 Vary 的方法,它是一个高效且用户友好的方法,用于解决上述问题。Vary 的灵感来源于传统大型语言模型(LLMs)中文本词汇扩展的方式,即,当将一个英语 LLM 转移到另一种外语(如中文)时,需要扩展文本词汇以提高新语言下的编码效率和模型性能。直观地说,对于视觉分支,如果我们向模型输入“外语” 图像(也就是没有见过的图像或者说不理解的图像),也需要扩大视觉词汇。

Vary,也就是一个扩大 LVLM 的 Vision vocabulary 的方法:

  • 第一阶段:生成一个新的视觉词汇表:使用 vocabulary network 和一个 tiny decoder-only transformer 来通过自回归产生需要的 vocabulary

    自回归的方式就是通过预测下一个 token 的方式来训练词汇模型,因为基于自回归的生成词汇的过程可能比基于对比学习的方式(如CLIP)更适合密集感知任务,原因有两个:

    ① 预测下一个 token 的方式可以允许视觉词汇压缩更长的文本

    ② 这种方式可以使用的数据格式更为多样,例如带有提示的VQA数据。在准备好新的视觉词汇后,我们将其添加到传统的 LVLMs 中以引入新特性。在此过程中,冻结了新旧词汇网络,以避免视觉知识被覆盖。

  • 第二阶段:整合新旧词汇表:通过将新产生的 vocabulary 和原来的 CLIP vocabulary 结合起来,让 LVLM 能很快的获得新的特征,在扩大视觉词汇后,LVLM 可以实现更细粒度的视觉感知,此外,作者提供了产生合成数据的方法

效果:

  • 相比于 BLIP-2, MiniGPT4 和 LLaVA, Vary 能在保持 LVLM 原始性能的同时,提供更好的精细感知和理解能力
  • Vary 能够在文档理解(document parsing,包括 OCR 或 markdown 转换),在 DocVQA 上获得了 78.2% ANLS,在 MMVet 上获得了 36.2% ANLS

在这里插入图片描述

二、方法

在这里插入图片描述

Vary 的整体结构如图 2 所示:

  • Vary-tiny:生成新的 vision vocabulary:

    • 由 vocabulary network 和 tiny OPT-125M 组成,在两个模块中间使用了线性层来进行通道维度对齐
    • 因为 Vary-tiny 主要是用于精细粒度的感知,所以 Vary-tiny 没有 text 输入分支
    • 作者期望 vision vocabulary network 是能够处理文档、表格等人造图像来弥补 CLIP 的不足,但同时又不能是 CLIP 的噪声,所以在训练的时候,是将人工造的文档或表格数据作为 positive samples,自然图片作为 negetives samples 来训练 vary-tiny 的
  • Vary-base:使用新的 vision vocabulary:

    • 在训练完 vary-tiny 之后,使用训练好的 vocabulary network 加到更大的模型上来构建 vary-base,如图 2 下半部分,新的和旧的 vocabulary network 的 input embedding layer 是独立的,在送入 LLM 之前会合并起来,在这个阶段,新旧 vocabulary network 的参数都是冻结的,其他模块的参数都是放开的

2.1 生成 new vision vocabulary

2.1.1 new vocabulary network

在这里插入图片描述

作者使用经过 SAM 预训练的 ViTDet 的 image encoder(base scale)作为 new vocabulary network 的主要部分

由于 SAM-base 的输入分辨率是 1024x1024,输出是 16x 下采样后的,最后一层的输出大小是 64x64x256,没法和 CLIP-L (256x1024 for NxC)的输出匹配上

所以,作者在 SAM 的最后一层后面加了两层卷积层,如图 3 所示,第一层卷积核大小为 3,将特征转换为 32x32x512,第二层卷积和第一层一样,将输出进一步转换成 16x16x1024,这样,就可以将输出和 CLIP-VIT 的 256x1024 对齐了

2.1.2 Data engine in the generating phrase

1、文档数据

作者选择高分辨率的文档图像-文本对作为新视觉词汇预训练的主要 positive 数据集,因为密集的OCR可以有效验证模型的细粒度图像感知能力。

据作者所知,目前没有公开可用的包含英文和中文文档的数据集,因此作者自己创建了一个。

作者首先从 arXiv 和 CC-MAIN-2021-31-PDF上来收集英文部分的 PDF 风格文档,并从互联网上的电子书中收集中文部分。

然后,使用 PyMuPDF 的 fitz 提取每个 PDF 页面的文本信息,并同时通过 pdf2image 将每页转换成 PNG 图像。在此过程中,作者构建了100万中文和100万英文的文档图像-文本对进行训练。

2、表格数据

作者发现当前的 LVLMs(大型视觉语言模型)在图表理解方面不是很好,尤其是中文图表,所以选择它作为另一个需要“编入”新词汇的主要知识。

对于图表图像-文本对,作者选择 matplotlib 和 pyecharts 作为渲染工具。对于 matplotlib 风格的图表,作者分别构建了25万中文和英文的图表。而对于 pyecharts,作者分别构建了50万中文和英文的图表。此外,作者将每个图表的文本真实值转换为 python 字典形式。图表中使用的文本,例如标题、x轴和y轴,是从互联网上下载的自然语言处理(NLP)语料库中随机选取的。

3、自然数据(作为负样本)

对于 CLIP-VIT 擅长的自然图像数据,作者需要确保新引入的词汇不会造成噪音。因此,作者构建了负面自然图像-文本对,以使新词汇网络在看到自然图像时能够正确编码。作者从COCO数据集[22]中提取了12万张图像,每张图像对应一段文本。

文本部分是随机选自以下句子:“这是一张自然图像”;“这里有一张自然图片”;“这是一张自然照片”;“这是一张自然图像”;“那是来自大自然的一张照片”。

2.1.3 输入的格式

作者使用自回归的方式,使用 image-text pairs 来训练 vary-tiny 的所有参数

输入的形式和现有的 LVLM 一致:

  • image token 和 text token 被打包起来,使用前缀区分
  • “” 和 “” 用来界定图像数据在输入序列中的位置。这样做可以让模型知道哪部分是图像,哪部分是文本。这些数据被输入到一个叫做OPT-125M的模型中,这个模型可以处理长达4096个令牌(token)的序列。这里的令牌可以是图像的一部分,也可以是文本的一部分。
  • 在训练过程中,尽管输入包含图像和文本,Vary-tiny 模型的输出仅为文本。此外,文本的结束标记符号是 “/s”,也就是 eos token,这告诉模型一段文本何时结束。

2.2 扩大 vision vocabulary

2.2.1 Vary-base 的结构

在完成词汇网络的训练之后,将其引入到语言-视觉多模态模型(LVLM)——Vary-base 中。

新的视觉词汇与原始的 CLIP-VIT 是并行的,这两个视觉词汇都拥有各自的输入嵌入层,即一个简单的线性层。

如图2所示,线性层的输入通道是1024,输出是2048,确保在拼接后图像令牌的通道数为4096,这正好与大型语言模型(LLM)的输入对齐(无论是Qwen-7B还是Vicuna-7B)

2.2.2 Data engine

作者通过下面这些方法来进行数据扩充

1、Latex 渲染的方式

除过上面收集的文档,还需要一些公式或表格数据,作者使用 latex 渲染的方式来生成一些相关数据

  • 首先,作者收集了一些 arxiv 上的 .txt 源文件

  • 然后,使用正则表达式提取了表格、数学公式和纯文本。

    在提取表格和公式的应用场景中,正则表达式可以这样工作:提取表格:在LaTeX文档中,表格通常使用\begin{table}和\end{table}标签包围。正则表达式可以被设计来搜索这些特定的标签及其之间的所有内容,从而提取整个表格。提取公式:类似地,数学公式在LaTeX中通常被 \begin{equation}和\end{equation}或者 . . . ... ...(对于内联公式)和 . . . ... ...或者[…](对于展示公式)所包围。正则表达式可以匹配这些模式来提取公式。

  • 最后,使用 pdflatex 重新渲染这些内容。作者收集了10多个模板来执行批量渲染。此外,每个文档页面的文本真实内容转换 为mathpix markdown 风格,以统一格式。通过这个构建过程,获得了50万页英文页面和40万页中文页面。一些样本展示在图4中。

    pdflatex是一个用于将LaTeX文档转换成PDF格式的命令行工具。LaTeX是一种基于TeX的排版系统,广泛用于生成科学和数学文献的复杂和高质量的文档。当你编写了一个LaTeX文档(通常是一个.tex文件)后,你需要通过一个编译过程将其转换成可读的文档,通常是PDF格式。pdflatex正是用于这种转换的工具之一。

在这里插入图片描述

2、语义关联图表渲染

在 2.1.2 节中,批量渲染图表数据来训练新的词汇网络。然而,这些渲染图表中的文本(标题、x轴值和y轴值)相关性较低,因为它们是随机生成的。这个问题在词汇生成过程中并不是问题,因为生成任务只希望新的词汇能够有效压缩视觉信息。然而,在Vary-base的训练阶段,由于解冻了LLM,希望使用更高质量(内容强相关)的数据进行训练。因此,使用 GPT-4[32] 来生成一些使用相关语料库的图表,然后我们利用高质量的语料库额外渲染了20万个图表数据用于Vary-base训练。

3、通用数据

Vary-base 的训练过程遵循流行的 LVLMs,例如 LLaVA[25],包括预训练和 SFT 阶段。与 LLaVA 不同的是,作者冻结了所有的词汇网络并解冻了输入嵌入层和 LLM,这更像是纯 LLM 的预训练设置。

作者使用自然图像-文本对数据来向 Vary-base 介绍通用概念。这些图像-文本对是从 LAION-COCO[37] 中随机提取的,数量为 400万。在 SFT 阶段,作者使用 LLaVA-80k 或 LLaVA-CC665k[24] 以及 DocVQA[29] 和 ChartQA[28] 的训练集作为微调数据集。

2.2.3 对话格式

当使用 Vicuna-7B 作为 LLM 时,对话的格式是和 Vicuna v1 [8] 相同的:

  • USER: “” “texts input”
  • ASSITANT: “texts output”

因为 Vicuna 处理中文很慢,所示使用 Qwen-7B [2] 作为 LLM 来处理中文,当使用 Qwen-7B [2] 处理中文的时候,对话格式参考的是 LLaVA-MPT [25, 41]:

  • <|im_start|>user: “” “texts input”<|im_end|> <|im_start|>assistant: “texts output” <|im_end|>.

三、效果

3.1 数据集

作者使用了多个数据集进行了测试:

  • 作者构建的 document-level OCR 测试集,主要是为了测试密集视觉感知能力:包括纯 OCR 和 markdown 转换任务
    • 纯 OCR 任务的测试集包括 100 张中英文数据,是随机从 arxiv 和 ebook 上抽取的
    • markdown 转换任务重,测试集包括 200 pages,其中 100 包括表格,另外 100 包括数学公式
  • DocVQA[29] 和 ChartQA [28],主要测试下游任务上的能力
  • MMVet[51],测试整体模型的效果

document parsing 测评指标:

  • Normalized Edit Distance
  • F1-Score
  • precision
  • recall

DocVQA, ChartQA, 和 MMVet 使用原来的测评

训练细节:

  • 对于词汇扩充任务,作者训练 vary-tiny 的全部参数,使用的 batch=512,epoch=3,optimizer=AdamW(cosine 退化),lr=5e-5
  • 在训练 vary-tiny 的时候,作者冻结了 new 和 vanilla(CLIP)的 vision vocabulary network,优化的是 input embedding layers 和 LLM
  • pretrain 预训练的时候 lr=5e-5,训练 SFT 的时候 lr=1e-5,预训练和 SFT 时 batch=256,epoch=1

归一化编辑距离:

  • OCR(光学字符识别)中的归一化编辑距离(Normalized Edit Distance,也称为Levenshtein距离)是一种衡量两个字符串相似度的方法。它通过计算将一个字符串转换成另一个字符串所需要的最少单字符编辑操作次数来实现。单字符编辑操作包括插入、删除和替换。

  • 编辑距离(Levenshtein距离):这是一个衡量两个字符串差异的指标,通过计算一个字符串转换成另一个字符串所需要的最小编辑操作数。这些操作通常包括:

    • 插入:在一个字符串中插入一个字符。
    • 删除:从一个字符串中删除一个字符。
    • 替换:将一个字符串中的一个字符替换成另一个字符。
  • 归一化编辑距离是将编辑距离除以两个字符串中较长的那个的长度,使得得到的值在0到1之间。这样可以消除字符串长度对比较结果的影响,让结果更加标准化。归一化编辑距离可以定义为:

    归一化编辑距离 = 编辑距离 max ⁡ ( 字符串1的长度 , 字符串2的长度 ) \text{归一化编辑距离} = \frac{\text{编辑距离}}{\max(\text{字符串1的长度}, \text{字符串2的长度})} 归一化编辑距离=max(字符串1的长度,字符串2的长度)编辑距离

  • 归一化编辑距离的值越接近 0,表示两个字符串越相似;值越接近1,则表示两个字符串差异越大。

  • 在OCR系统中,归一化编辑距离常用来评估OCR输出和实际文本之间的差异,以此来衡量OCR系统的准确性。如果OCR输出的文本和实际文本的归一化编辑距离很小,那么可以认为OCR系统具有较高的识别准确率。反之,如果归一化编辑距离较大,则说明OCR系统可能在文本识别上存在较多错误。

3.2 图像细粒度感知能力

作者通过密集文本识别能力来衡量 Vary 的细粒度感知性能。

如表1所示,Vary-tiny 通过视觉词汇生成过程,集合了中文和英文的密集OCR能力:

  • 它在中文和英文文件(纯文本)OCR上分别实现了0.266和0.197的编辑距离,这证明了新视觉词汇具有良好的细粒度文本编码能力。
  • 对于Vary-base,它在英文纯文本文件上可以达到与 nougat(一种特殊的文档解析模型)相当的性能。

此外,使用不同的提示(例如,将图像转换为markdown格式),Vary-base 可以实现文档图像到 markdown 格式的转换。

值得注意的是,在这样的任务中,Vary-base(在数学和表格平均值上具有0.181的编辑距离和81.10%的F1得分)在某种程度上比nougat(平均0.245的编辑距离和79.97%的F1得分)要好,这可能是由于7B LLM(Qwen)超强的文本纠正能力。

所有上述结果表明,通过扩展视觉词汇,新的LVLM可以提升其细粒度感知性能。
在这里插入图片描述

3.3 下游任务

作者在 DocVQA [29] 和 ChartQA [28] 两个下游视觉问答(VQA)任务上测试了性能提升。

作者使用了额外的提示:“使用单个单词或短语回答以下问题:”[24],以便模型输出简短且精确的答案。

如表 2 所示,Vary-base(以Qwen-7B作为大型语言模型LLM)在DocVQA上,基于LLaVA-80k [25] 的 SFT(特定任务微调)数据,可以达到 78.2%(测试集)和 76.3%(验证集)的 ANLS 得分。

使用 LLaVA-665k [24] 数据进行 SFT,Vary-base 在 ChartQA 上的平均性能可以达到 66.1%。

在这两个具有挑战性的下游任务上的表现可与 Qwen-VL [4]相媲美,甚至更好,这证明了本文提出的视觉词汇扩展方法对于下游任务也是有前景的。

在这里插入图片描述

3.4 通用效果

作者通过 MMVet [51] 基准测试来监控 Vary 的整体性能。

如表3所示,使用相同的大型语言模型(Vicuna-7B)和特定任务微调数据(LLaVA-CC665k),Vary的性能提升了 2.4%(从 30.5% 提升至 32.9%),这证明了本文的数据和训练策略没有损害模型的通用能力。

此外,结合 Qwen-7B 和 LLaVA-80k 的 Vary 可以达到 36.2% 的性能,进一步证明了我们扩大视觉词汇量的有效性。

在这里插入图片描述

3.5 其他效果展示

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

四、代码

4.1 数据预处理

本文的方法是多模态的方法,所以有两种数据预处理方式,即文本和图片分别预处理,在预处理之前呢,需要将训练的图片和标注结果先进行格式处理,处理到如下的格式(来自作者github issue 中):

在这里插入图片描述

def __getitem__(self, i) -> Dict[str, torch.Tensor]:
     # data = self.list_data_dict[i]
     data = copy.deepcopy(self.list_data_dict[i])

     if isinstance(data, dict):
         if 'image' in data:
             image_path = self.list_image_path[i]
             image_file = data['image']

             try:
                 image = Image.open(image_path + image_file).convert('RGB')
             except:
                 print(f'cannot identify image file {image_path + image_file}.')
                 return self.__getitem__(0)

             try:
                 image, image_1 = self.image_processor(image)
             except:
                 print(f'image {image_file} are broken or grayscale! we thus select 0-th sample instead!')
                 return self.__getitem__(0)

         conversations = self.multimodal_processor([data["conversations"]])

     else:
         conversations = [data]

     # align with fastchat & llava here, put the conversation into a list for tokenization
     data_dict = self.token_processor(conversations)
     data_dict = dict(input_ids=data_dict["input_ids"][0], labels=data_dict["labels"][0])
     
     if isinstance(data, dict) and 'image' in data:
         data_dict['image'] = [image]
         data_dict['image_high'] = [image_1]
     else:
         crop_size = self.multimodal_cfg['image_processor'].crop_size
         data_dict['image'] = [torch.zeros(3, crop_size['height'], crop_size['width'])]
         data_dict['image_high'] = [torch.zeros(3, 1024, 1024)]
     return data_dict

图像预处理:

  • clip:224x224
  • ViT:1024x1024

文本 prompt 预处理:

conversation:

处理之前:

[[{'from': 'human', 'value': '<image>Provide the ocr results of this image'}, {'from': 'gpt', 'value': '5. 二次函数$$ y=ax^{2} $$的图象经过点(2,-2)\n(1)求这个函数的表达式;\n(2)当x为何值时,函数y随x的增大而增大?'}]]

处理之后:

[[{'from': 'human', 'value': '<img><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad></img>Provide the ocr results of this image'}, {'from': 'gpt', 'value': '5. 二次函数$$ y=ax^{2} $$的图象经过点(2,-2)\n(1)求这个函数的表达式;\n(2)当x为何值时,函数y随x的增大而增大?'}]]

data_dict = self.token_processor(conversations):

data_dict:{
'input_ids': tensor([151644,   8948,    198,   2610,   1265,   1795,    279,  11221,  15516,
        323,  10339,    697,  11253,    304,   7716,     13, 151645, 151644,
        872,    198, 151857, 151859, 151859, 151859, 151859, 151859, 151859,
        151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859,
        151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859,
        151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859,
        151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859,
        151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859,
        151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859,
        151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859,
        151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859,
        151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859,
        151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859,
        151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859,
        151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859,
        151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859,
        151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859,
        151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859,
        151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859,
        151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859,
        151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859,
        151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859,
        151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859,
        151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859,
        151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859,
        151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859,
        151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859,
        151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859,
        151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859,
        151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859,
        151859, 151859, 151859, 151859, 151859, 151859, 151859, 151858,     90,
         40581,     92, 151645, 151644,  77091,    198,     20,     13,    220,
        105935,  32804,  14085,    379,  71663,  47822,     17,     92,  26107,
          9370,  28029,  46423, 101897,  27442,      7,     17,   4999,     17,
           340,      7,     16,      8,  30918,  99487,  32804,   9370, 102124,
         28330,    280,      7,     17,      8,  39165,     87, 104499,  25511,
         13343,     11,  32804,     88,  99411,     87,   9370, 108696,  68536,
        108696,     30, 151645]), 
 'labels': tensor([  -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
          -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
          -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
          -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
          -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
          -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
          -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
          -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
          -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
          -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
          -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
          -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
          -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
          -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
          -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
          -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
          -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
          -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
          -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
          -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
          -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
          -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
          -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
          -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
          -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
          -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
          -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
          -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
          -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
          -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
          -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
          -100,   -100,   -100,   -100,   -100,   -100,     20,     13,    220,
        105935,  32804,  14085,    379,  71663,  47822,     17,     92,  26107,
          9370,  28029,  46423, 101897,  27442,      7,     17,   4999,     17,
           340,      7,     16,      8,  30918,  99487,  32804,   9370, 102124,
         28330,    280,      7,     17,      8,  39165,     87, 104499,  25511,
         13343,     11,  32804,     88,  99411,     87,   9370, 108696,  68536,
        108696,     30, 151645])}

怎么处理的呢:

conversation 会被处理如下,然后提取每个文字对应的 token id,得到 input_ids,然后 label 就是 input_id

conversations: [
'<|im_start|>system\nYou should follow the instructions carefully and explain your answers in detail.<|im_end|>
<|im_start|>user\n<img><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad>...<imgpad><imgpad></img>Provide the ocr results of this image.<|im_end|>
<|im_start|>assistant\n5. 二次函数$$ y=ax^{2} $$的图象经过点(2,-2)\n(1)求这个函数的表达式;\n(2)当x为何值时,函数y随x的增大而增大?<|im_end|>']

然后每个字母和 id 就是这样对应的:

  • <|im_start|>:151644
  • system:8948
  • \n:198
  • You:2610
  • should:1265

可以看到 conversation 中有三部分:

  • system
  • user
  • gpt

只有 gpt 回答部分需要学习,前两部分是不需要学习的,所以会将 input_ids 中这两部分的 token id 全置换为 -100,然后作为 label 用于 LLM 模型的学习。

4.2 常用模块解释

1、AutoTokenizer

这个类用于加载预训练的tokenizer(词元化器)。Tokenizer负责将原始文本转换成模型能够理解的格式,即将句子分解成词或子词单元(tokens),这些单元可以是单词、字母或者是词根等。此外,它还负责将这些tokens转换为模型需要的数字ID,并且可以自动处理添加特殊token,比如序列开始和结束标记。

2、AutoModelForCausalLM

这个类用于加载预训练的因果语言模型(Causal Language Model)。因果语言模型是一种自回归模型,它基于给定的一系列词元(例如一个句子中的单词)来预测下一个词元。这种模型常用于生成文本,如聊天机器人、文本补全、故事生成等。“Causal”(因果的)一词意味着模型的预测只依赖于先前的词元,而不是未来的词元。

简而言之,使用 AutoTokenizer 可以准备数据,然后使用 AutoModelForCausalLM 来生成或者继续生成文本

3、CLIPImageProcessor
不直接提取图像特征,而是准备图像数据,使其能够被 CLIP 模型正确处理。它通常执行的操作包括调整图像大小、归一化像素值等,以匹配训练时使用的格式。这是数据预处理的一个步骤,它确保输入数据与模型在训练时接收的数据格式相同。

具体提取图像特征的是 CLIP 模型的图像编码器部分。在 CLIP 模型中,有两个主要的组件:

图像编码器(Image Encoder):它负责将图像转换成嵌入向量(即特征表示)。这通常是一个预训练的卷积神经网络(CNN)或者变换器(Transformer)架构,它将输入图像转换为一维的特征向量。

文本编码器(Text Encoder):它负责将文本输入转换成嵌入向量,使其与图像嵌入处于相同的嵌入空间。

CLIP 模型的核心思想是同时训练图像和文本编码器,使得它们能够将图像和文本映射到共同的嵌入空间中,从而可以比较图像和文本的相似度。

CLIPImageProcessor {
  "crop_size": {
    "height": 224,
    "width": 224
  },
  "do_center_crop": true,
  "do_convert_rgb": true,
  "do_normalize": true,
  "do_rescale": true,
  "do_resize": true,
  "feature_extractor_type": "CLIPFeatureExtractor",
  "image_mean": [
    0.48145466,
    0.4578275,
    0.40821073
  ],
  "image_processor_type": "CLIPImageProcessor",
  "image_std": [
    0.26862954,
    0.26130258,
    0.27577711
  ],
  "resample": 3,
  "rescale_factor": 0.00392156862745098,
  "size": {
    "shortest_edge": 224
  }
}

经过处理后,会得到 224x224 的图片

4、image_processor_high

经过处理后,得到 1024x1024 的图片

4.3 模型训练流程

1、经过处理后的图片(224 和 1024)经过 varyQwenModel 进行特征提取,并经过 projector 后送入 LLM,得到输出

这两组图片(224x224 和 1024x1024)都会送入代码,分别提取特征(使用 varyQwenModel),一个是保证 CLIP 原始的输出,一个是新的更细粒度的输出

CLIP 的特征抽取器是 :

# vision_tower = getattr(self, 'vision_tower', None)
vision_tower
CLIPVisionModel(
  (vision_model): CLIPVisionTransformer(
    (embeddings): CLIPVisionEmbeddings(
      (patch_embedding): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14), bias=False)
      (position_embedding): Embedding(257, 1024)
    )
    (pre_layrnorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
    (encoder): CLIPEncoder(
      (layers): ModuleList(
        (0-23): 24 x CLIPEncoderLayer(
          (self_attn): CLIPAttention(
            (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
          )
          (layer_norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (mlp): CLIPMLP(
            (activation_fn): QuickGELUActivation()
            (fc1): Linear(in_features=1024, out_features=4096, bias=True)
            (fc2): Linear(in_features=4096, out_features=1024, bias=True)
          )
          (layer_norm2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        )
      )
    )
    (post_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
  )
)

这个特征抽取器的输出如下,使用的是 image_forward_out.hidden_states[vision_select_layer],也就是 image_forward_out.hidden_states[-2],为什么要用倒数第二个隐层的输出,我查过后的解释在下边(不知道是否客观)

这个特征抽取器的最终使用的就是倒数第二个隐藏层的输出,倒数第二个隐藏层的特征维度为 [1, 257, 1024],应该是包含了一个 [CLS] token(第一个),所以取了除第一个 token 特征之外的所有特征,维度为 [1, 256, 1024]。

for image in images:
     with torch.set_grad_enabled(False):
     	 # CLIP 的输出
         image_forward_out = vision_tower(image[0], output_hidden_states=True)
         # image_forward_out.keys(): odict_keys(['last_hidden_state', 'pooler_output', 'hidden_states'])
         # image_forward_out['last_hidden_state'].shape=torch.Size([1, 257, 1024]),这通常指模型最后一层的输出,对于图像来说,这可能是最后一个卷积层或者是transformer层的输出特征表示。这个输出通常用于下游任务,如图像分类、目标检测等。
         # image_forward_out['pooler_output'].shape=torch.Size([1, 1024]),在某些模型中(如BERT),pooler_output是经过池化操作的最后一层的输出,通常用于获取整个输入图像的单个固定大小的表示。在图像模型中,这可能是经过一系列卷积层和池化层后得到的全局特征表示。
         # len(image_forward_out.hidden_states)=25, image_forward_out['hidden_states'][0].shape=torch.Size([1, 257, 1024]),这是模型中所有层的输出的集合。对于transformer模型,这将包含每个attention层的输出。这些隐藏状态可以用于深入分析模型的行为,或者在某些高级应用中,比如特征提取或者转移学习。
         select_hidden_state = image_forward_out.hidden_states[vision_select_layer] #torch.Size([1, 257, 1024]), 在Transformer模型中,通常会有一个特殊的标记,比如[CLS]标记,它用于聚合序列的全局信息,经常用于分类任务。如果输入序列本身有256个元素,加上这个特殊的标记,就会得到257个元素。在视觉Transformer中,这个额外的元素可能是一个代表整个图像的全局信息的标记。
         image_feature = select_hidden_state[:, 1:]  # torch.Size([1, 256, 1024])
     with torch.set_grad_enabled(False):
         cnn_feature = vision_tower_high(image[1])
         cnn_feature = cnn_feature.flatten(2).permute(0, 2, 1) # torch.Size([1, 256, 1024])
     image_features_1.append(image_feature)
     image_features_2.append(cnn_feature)

这里的 (0-11): 12 x Block 表示 blocks 这个 ModuleList 包含了从索引 0 到 11 的12个 Block 实例,每一个 Block 都有相同的结构。每个 Block 包含了两个 LayerNorm 层,一个 Attention 模块(包含 qkv 和 proj 线性层),以及一个多层感知器(MLPBlock),后者包含两个线性层和一个GELU激活函数。

在自然语言处理(NLP)中,tokenize这个术语通常指的是将文本分割成更小的单元,这些单元可以是单词、子词或字符。这个过程对于准备文本数据以供深度学习模型如GPT(生成预训练变换器)使用是必不可少的。

input_ids是tokenize过程的产物之一。当文本被分割成tokens之后,每个token会被映射到一个唯一的数字ID。这些数字ID是预先定义的,并且存储在模型使用的词汇表中。这个映射允许模型以数值形式处理文本数据,因为模型不能直接理解原始文本。

例如,考虑这个句子:“Hello, world!”。tokenize过程可能会将其分割成[“Hello”, “,”, “world”, “!”]。然后,这些tokens会根据词汇表映射到对应的input_ids:[7592, 16, 2080, 328]。这些ID是模型可以理解的形式,因为模型的嵌入层会使用这些ID来查找每个token的嵌入表示。

经过 CLIP vision tower 后,通常选择其 hidden_states 的倒数第二层来作为图像的特征,原因如下:

在深度学习模型中,尤其是在类似于Transformer的结构中,每一层都会输出一个中间表示的特征集。这些特征在不同的层捕获了不同层次的信息。通常,最后一层的输出被认为是最抽象的,包含了输入数据的高级特征表示,这对于许多任务是有用的。然而,实践中发现,最后一层的表示有时可能过于特化于模型在训练期间所执行的任务。

从倒数第二层(或者其他非最后一层)获取特征的动机可能包括:

  • 泛化能力:倒数第二层(或其他中间层)的特征可能比最后一层的特征更具泛化能力。对于某些任务,中间层的特征可能更适合迁移到新的、未见过的任务或数据集上。

  • 过拟合的减少:最后一层的特征可能对训练数据过度拟合,尤其是在模型极度复杂或训练数据有限的情况下。使用倒数第二层的特征可能有助于减少这种过拟合。

  • 信息的丰富性:最后一层之前的层可能保留了更多的信息,因为最后一层可能会丢弃对当前训练任务不重要的信息。在一些情况下,这些“不重要”的信息可能对于新任务来说是有用的。

  • 实验结果:在实际应用中,研究人员可能会发现通过实验,倒数第二层的特征在特定的下游任务上性能更好。这可以是经验上的选择。

  • 计算效率:在某些情况下,最后一层可能会进行额外的操作(如池化或特殊的注意力机制),这可能不是必要的。使用倒数第二层的输出可以避免这些可能不必要的计算,从而提高效率。

更细粒度的 SAM 的 ViT 的特征抽取器:

# vision_tower_high = getattr(self, 'vision_tower_high', None)
vision_tower_high
ImageEncoderViT(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
  )
  (blocks): ModuleList(
    (0-11): 12 x Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (proj): Linear(in_features=768, out_features=768, bias=True)
      )
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): MLPBlock(
        (lin1): Linear(in_features=768, out_features=3072, bias=True)
        (lin2): Linear(in_features=3072, out_features=768, bias=True)
        (act): GELU(approximate='none')
      )
    )
  )
  (neck): Sequential(
    (0): Conv2d(768, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (1): LayerNorm2d()
    (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (3): LayerNorm2d()
  )
  (net_2): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
  (net_3): Conv2d(512, 1024, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
)

这个特征抽取器输出的特征维度为 torch.Size([1, 1024, 16, 16]),拉平后特征维度为 torch.Size([1, 256, 1024])

看到这里 SAM 的 ViT 原本的输出维度为 64x64x256,后面的 net_2 和 net_3 是作者加入的层,主要为了改变特征维度

with torch.set_grad_enabled(False):
      cnn_feature = vision_tower_high(image[1]) # torch.Size([1, 1024, 16, 16])
      cnn_feature = cnn_feature.flatten(2).permute(0, 2, 1) # torch.Size([1, 256, 1024])

两组图片特征会分别经过 projector,然后 concat 起来,送入 LLM 以供学习

if type(images) is list: 
      image_features_1 = [self.mm_projector(image_feature) for image_feature in image_features_1] # Linear(in_features=1024, out_features=1024, bias=True), [1, 256, 1024]
      image_features_2 = [self.mm_projector_vary(image_feature) for image_feature in image_features_2] # Linear(in_features=1024, out_features=1024, bias=True), [1, 256, 1024]
      image_features = [torch.cat((image_feature[0], image_feature[1]), dim=-1) for image_feature in zip(image_features_1, image_features_2)] # image_features[0].shape=torch.Size([1, 256, 2048])

2、图像特征如何送入 LLM

  • input_ids:tokenized 分词后的 id,torch.Size([1, 679]),序列的长度是679个令牌。换句话说,输入文本在分词后生成了679个令牌,每个令牌都被转换成了对应的整数ID
  • input_embedding:torch.Size([1, 679, 2048]), 将 input_ids 经过 word to embedding 映射,给每个 id 得到一个 2048 维的特征, self.wte = nn.Embedding(self.vocab_size, self.embed_dim),Embedding(151860, 2048),if inputs_embeds is None: inputs_embeds = self.wte(input_ids)
  • image_features:image_features[0].shape=torch.Size([1, 256, 2048])

如何从 input_ids ([1,679]) 得到 input_embedding ([1, 679, 2048]):使用 nn.Embedding

nn.Embedding 类在 PyTorch 中是用来创建一个简单的查找表,该表将每个 token 的整数索引映射到一个高维空间的向量。当你创建一个 Embedding(151860, 2048) 实例时,你实际上在内存中初始化了一个矩阵,其中有 151860 行(对应于词汇表中的每个 token)和 2048 列(对应于每个 token 的嵌入向量的维度)。

内部结构可以想象成一个大型的矩阵,其形状为 (vocab_size, embed_dim),即 (151860, 2048)。当你用一个 token 的索引来索引这个嵌入层时,它会返回该索引对应的行,也就是该 token 的嵌入向量。这个过程通常称为“查找”。

在训练开始时,这个矩阵的值通常是随机初始化的,然后在神经网络的训练过程中通过反向传播算法进行调整。目标是使得这些嵌入向量在完成特定任务(如分类、翻译、情感分析等)时能够捕捉到词汇之间的语义关系。

例如,如果你有一个包含 token 索引的 PyTorch 张量 input_ids,你可以通过调用嵌入层 self.wte(input_ids) 来获取这些 token 的嵌入向量。这个操作在内部等价于从嵌入矩阵中选择特定行。每个 token id 对应嵌入矩阵中的一行,而这行的内容就是该 token 的嵌入向量。

这里有三个 token ID 要注意:

  • 图像开始的 token id:151857
  • 图像块儿的 token id:151859
  • 图像结束的 token id:151858

将上面得到的所有 token embedding 特征(图像开始之前的 + 图像的 + 图像结束后的),假设共 679x2048 维,也就是共 679 个 token

我们看到这里,之前的图像都是使用 151859 这个 token id 来表示的,所以编码后的 token embedding(每个 embedding 的维度为2048),我们在前面通过 vision_tower 计算出了 image feature,维度为(256x2048),所以这里将 151859 对应的 image token 使用计算到的 image feature 替换,维度是一样的 2048,替换后的 cur_input_embeds 就是包括图像特征和文本特征的 embedding 了。

得到了 input_ids 以及包括 system+user+gpt 所有 token 的 embedding 特征后,然后代码里边将这些东西都送入了 varyQwenModel,也就是如下这样的,这里把多模态的 token embedding 整个过一遍 QwenModel 的 forward,也就是这里就算是输入了语言大模型

# # 这里进入的是 Vary-master/vary/model/llm/qwen/modeling_qwen.py(543)forward()
return super(varyQwenModel, self).forward(
     input_ids=None, attention_mask=attention_mask, past_key_values=past_key_values,
     inputs_embeds=inputs_embeds, use_cache=use_cache,
     output_attentions=output_attentions, output_hidden_states=output_hidden_states,
     return_dict=return_dict)

大模型的结构是这样的,也就是会将多模态 token embedding 的特征送入多层网络组成的 QWen 模型中,使用预训练好的模型来得到输出 output,每层的输出都是 [1, 679,2048],第一层的输入是上面得到的多模态 token embedding,之后每层的输入都是上一层的输出,一共进行 24 次,最终得到的特征维度还是 [1, 679,2048]

注意:这里的大语言模型的参数是放开训练的,不会冻结!

self.h
ModuleList(
  (0-23): 24 x QWenBlock(
    (ln_1): RMSNorm()
    (attn): QWenAttention(
      (c_attn): Linear(in_features=2048, out_features=6144, bias=True)
      (c_proj): Linear(in_features=2048, out_features=2048, bias=False)
      (rotary_emb): RotaryEmbedding()
      (attn_dropout): Dropout(p=0.0, inplace=False)
    )
    (ln_2): RMSNorm()
    (mlp): QWenMLP(
      (w1): Linear(in_features=2048, out_features=5504, bias=False)
      (w2): Linear(in_features=2048, out_features=5504, bias=False)
      (c_proj): Linear(in_features=5504, out_features=2048, bias=False)
    )
  )
)

经过了大语言模型的处理后,这些所有的 token 会经过一个 self.lm_head 头 (Linear(in_features=2048, out_features=151860, bias=False),将 torch.Size([1, 679, 2048]) 映射到 torch.Size([1, 679, 151860])

所以 LLM 的参数和 self.lm_head 头的参数都会参与训练

3、计算 loss

  • label:[1, 679]
  • -100 是不需要参与 loss 计算的,只有标签中 gpt 的 value 内容(也就是标注内容)需要参与 loss 计算
# label
tensor([[  -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,  54623,  46944, 101675,
          82699,     11,  32555, 104121, 113759,  99319,  45861, 105706,     19,
           6226,   5373,     18,   6226,   5373,     17,     13,     20,   6226,
            624, 100213,  17447, 108112,    715, 100380,  26939, 100213,  63109,
           9370,    198, 100764,  48921,  49567,     13,    715,  29524,  28029,
             24,     13,     16,     13,     16,     18,     11,  60726,  54623,
          43268,  37474,  14085,  14137,     28,     19,   6226,  26107,     11,
         101889,  23031,  27442,     32,  17714, 100213,    198,  63109,   5373,
             18,   6226,  45861,  17714,  99369,  66569,  54623, 100213, 108718,
             11,  87256,  23031,  27442,     33,  17714, 100213,  63109,   5373,
             17,     13,     20,   6226,  45861,  17714,    198,  99369,  66569,
          54623, 100213, 108718,     11,  77540, 108718,  48921,  38109,  34204,
          27442,     34,     11,  54926,  36885,   1706,   5373,   4897,     13,
           1124,  55114,  19360,  99486,    198,  31838,  30534,  54623,   9370,
         101675,  82699,     13, 151645]], device='cuda:0')
  • 预测的结果

然后求 crossentropy loss

{'loss': 1.0492, 'learning_rate': 3.094059405940594e-08, 'epoch': 0.0}  

4.4 模型推理流程

数据处理:

1、prompt 处理

  qs = 'Provide the ocr results of this image.'

  if use_im_start_end:
      qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN*image_token_len + DEFAULT_IM_END_TOKEN  + '\n' + qs
  else:
      qs = DEFAULT_IMAGE_TOKEN + '\n' + qs

    # qs: '<img><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad></img>\nProvide the ocr results of this image.'

2、prompt 转换 conversation

# '<|im_start|>system\nYou should follow the instructions carefully and explain your answers in detail.<|im_end|><|im_start|>user\n<img><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad><imgpad></img>\nProvide the ocr results of this image.<|im_end|><|im_start|>assistant\n'

3、conversation 分词得到 input_ids

input_ids 的维度为:292,推理的时候图像分块 256,prompt 都是一样的,所以推理的时候 input_ids 都是一样的

[151644, 8948, 198, 2610, 1265, 1795, 279, 11221, 15516, 323, 10339, 697, 11253, 304, 7716, 13, 151645, 151644, 872, 198, 151857, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151858, 198, 60424, 279, 297, 5082, 3059, 315, 419, 2168, 13, 151645, 151644, 77091, 198]

模型推理:流式推理,需要多次推理才能得到最终的输出句子结果

这里流式推理调用的是 transformer 库的 generate 方法,image 参数是通过 **kwargs 给入的,这也是多模态模型比语言模型多的一个输入

generate 的这里(sample)会调用 Vary 模型:

# forward pass to get next token
outputs = self(
    **model_inputs,
    return_dict=True,
    output_attentions=output_attentions,
    output_hidden_states=output_hidden_states,
)
#  outputs.keys()=odict_keys(['logits', 'past_key_values'])
'''
outputs.logits
tensor([[[ 3.5938,  1.4141,  2.6719,  ..., -3.1562, -1.8438, -1.7812],
         [ 3.5938,  1.4062,  2.6250,  ..., -3.1719, -1.8516, -1.7891],
         [-2.7656, -0.9883, -2.4219,  ..., -2.9688, -3.5781, -3.5938],
         ...,
         [ 4.8438,  5.0938,  2.7188,  ..., -0.0109, -2.1562, -2.1250],
         [ 5.5312,  6.0000,  3.1875,  ..., -1.8281, -3.9531, -3.9688],
         [ 7.9062,  6.7188,  9.3750,  ..., -1.2969, -2.9844, -2.9531]]],
       device='cuda:0', dtype=torch.bfloat16)
(Pdb) outputs.logits.shape
torch.Size([1, 292, 151860])
'''

if synced_gpus and this_peer_finished:
    continue  # don't waste resources running the code we don't need

next_token_logits = outputs.logits[:, -1, :] # [1, 1, 151860],也就是最后一个 token

# pre-process distribution
next_token_scores = logits_processor(input_ids, next_token_logits)
next_token_scores = logits_warper(input_ids, next_token_scores)
# Store scores, attentions and hidden_states when required
if return_dict_in_generate:
     if output_scores:
         scores += (next_token_scores,)
     if output_logits:
         raw_logits += (next_token_logits,)
     if output_attentions:
         decoder_attentions += (
             (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
         )
         if self.config.is_encoder_decoder:
             cross_attentions += (outputs.cross_attentions,)

     if output_hidden_states:
         decoder_hidden_states += (
             (outputs.decoder_hidden_states,)
             if self.config.is_encoder_decoder
             else (outputs.hidden_states,)
         )

# sample
probs = nn.functional.softmax(next_token_scores, dim=-1) # tensor([[0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0'),会有一个为 1
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) # tensor([18493], device='cuda:0')
# finished sentences should have their next token be a padding token
if eos_token_id is not None:
   if pad_token_id is None:
       raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
   next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

# update generated ids, model inputs, and length for next step
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
if streamer is not None:
   streamer.put(next_tokens.cpu())
model_kwargs = self._update_model_kwargs_for_generation(
   outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
)

# if eos_token was found in one sentence, set sentence to finished
if eos_token_id_tensor is not None:
   unfinished_sequences = unfinished_sequences.mul(
       next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
   )

unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)

# stop when each sentence is finished
if unfinished_sequences.max() == 0:
   this_peer_finished = True

if this_peer_finished and not synced_gpus:
   break

if streamer is not None:
     streamer.end()

if return_dict_in_generate:
    if self.config.is_encoder_decoder:
        return GenerateEncoderDecoderOutput(
            sequences=input_ids,
            scores=scores,
            logits=raw_logits,
            encoder_attentions=encoder_attentions,
            encoder_hidden_states=encoder_hidden_states,
            decoder_attentions=decoder_attentions,
            cross_attentions=cross_attentions,
            decoder_hidden_states=decoder_hidden_states,
            past_key_values=model_kwargs.get("past_key_values"),
        )
    else:
        return GenerateDecoderOnlyOutput(
            sequences=input_ids,
            scores=scores,
            logits=raw_logits,
            attentions=decoder_attentions,
            hidden_states=decoder_hidden_states,
            past_key_values=model_kwargs.get("past_key_values"),
        )
else:
    return input_ids
  • 第一轮:对输入图像进行特征提取,替换掉 input_embeded 中 imgpad 部分(也就是图片token)的编码,将替换后的多模态 input_embeded (torch.Size([292, 2048])) 和 input_ids (torch.Size([1, 292]))送入 QWen 大语言模型进行特征提取,得到输出,维度为 torch.Size([1, 292, 2048]),[1, 292, 2048] 维度的多模态特征经过 lm_head,变成 [1, 292, 151860],然后选择最后一个 token 的特征作为输入,预测出下一个可能出现的 token id,这里的预测使用的就是 torch.multinomial(probs, num_samples=1).squeeze(1) 预测的,然后就会得到第一个预测到的 token id(tensor([18493], device=‘cuda:0’))
  • 第二轮:将第一轮预测的 token 作为输入(这里就是 generate 中的 sample 方法会一直调用 self 方法,也就是 vary 方法),不经过 vision tower 的处理,而是只经过 QWen 语言模型的处理,然后得到下一个 token 输出,维度为 torch.Size([1, 1, 2048])
  • 第三轮:将第二轮预测的 token 作为输入,预测下一个输出
  • 第 N 轮:同上
  • 最终流式的得到一句话

在训练和推理(生成)阶段的不同:

训练阶段:

在训练阶段,模型通常使用了称为“教师强制”(Teacher Forcing)的技术。在这种方法中,模型在每一时间步都被提供了正确的上下文,即使它在之前的时间步预测错误。这意味着模型的每一步输入都是真实的序列,而不是由模型自己生成的序列。这样做的好处是可以加速训练过程,因为模型不需要等待它自己生成下一个token,而是立即获得下一个正确的token作为输入。

此外,训练时通常使用的是交叉熵损失函数,这需要同时知道模型对整个序列的预测和真实序列,以便计算损失。因此,在训练时,模型会一次性看到整个序列,并一次性计算所有输出的概率分布,然后通过反向传播更新权重。

推理(生成)阶段:

而在推理阶段,目标是生成新的文本序列。由于我们不知道目标序列是什么,模型需要根据已生成的内容自行预测下一个token。在这个过程中,模型使用自回归的方式,每次生成一个token,然后将其作为下一步的输入。这个过程是迭代的,因为每个新的预测都建立在前一个预测的基础上。

在自回归生成中,模型不能一次性看到整个正确的序列,因为生成是动态进行的。因此,它需要逐步生成,每次添加一个新的元素到序列中。这种方法允许模型根据已生成的序列的上下文来生成下一个最合适的token。

综上所述,训练阶段和推理阶段的差异主要在于是否有访问整个正确序列的信息,以及模型是一次性生成整个序列的概率分布还是逐步生成序列。在训练时使用教师强制可以提高训练效率,而在推理时使用自回归生成则可以在没有正确答案的情况下生成连贯的文本。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

呆呆的猫

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值