ImageBind-LLM: Multi-modality Instruction Tuning 论文阅读笔记

本文主要基于LLaMA和ImageBind工作,结合多模态信息和文本指令来实现一系列任务。训练中仅使用图像文本信息作为多模态信息提取能力的训练数据(only leverage the vision-language data for multi-modality instruction tuning)。Github代码 link.

Method 方法

对于一个图像文本对,

  1. 使用来自ImageBind工作、预训练好、冻结参数的图像encoder来提取全局的图像特征(utilize the frozen image encoder of ImageBind to extract the global image feature)。
  2. 使用一个可学习的bind network来对齐 前面ImageBind encoder 和 后面LLaMA的特征空间,得到处理后的transformed image feature(adopt a learnable bind network to align the embedding space between LLaMA and ImageBind’s image encoder)。
  3. 将图像特征(多模态数据特征)transformed image feature与LLaMA的文本知识融合:将transformed image feature与LLaMA中每个transformer层的每个word tokens相加(the transformed image feature is added to the word tokens at all transformer layers in LLaMA)。并且设置了一个初始值为0、可学习的门参数 g z e r o g_{zero} gzero来控制特征融合的程度,
    T j = T I ∗ g z e r o + T W j T^j=T_I*g_{zero} + T{_W}{^j} Tj=TIgzero+TWj
    门参数的设置可以使得模型训练初期保持稳定,门参数的数值一般随着训练会逐渐增加。

所以整个模型可以分为两个阶段的训练,

  1. vision-language pretraining on image-caption data to learn the image-conditioned response capacity
    基于ImageBind的encoder,模型也可以理解图像之外其他模态的信息
  2. multi-modality instruction tuning on visual instruction data
    基于non-instruction model LLaMA,输入文本指令(language instruction)来学习长句生成能力(long-sentence generation quality)。本阶段仅使用图像文本数据来微调模型,并且冻结Imagebind encoder和Bind network的参数。

在这里插入图片描述

Bind Network

主要作用是对齐ImageBind和LLaMA之间的特征空间。
在这里插入图片描述
代码实现:

import torch
import torch.nn as nn
import torch.nn.functional as F

# Define the RMSNorm 
class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        return output * self.weight
      

# Define the repeated feedforward block in bind network 
class FeedForwardBlock(nn.Module):
    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()

        # normalize the input 
        self.norm = RMSNorm(dim)

        # Define 3 linear projection layers whose parameters are w1, w2 and w3 respectively.
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(dim, hidden_dim, bias=False)
        self.w3 = nn.Linear(hidden_dim, dim, bias=False)

    def forward(self, x):
        # cascade linear linears with RMSNorm, SiLU activation functions and residual connections
        x = self.norm(x)
        return x + self.w3(F.silu(self.w1(x)) * self.w2(x))

class bind_network(nn.Module):
    def __init__(self, args):
        super.__init__()
        self.image_dim = args.image_dim # e.g., 1024, encoded by ImageBind
        self.model_dim = args.model_dim # e.g., 4096
        self.ffn_dim = self.model_dim * 4 # 

        self.linear_0 = nn.Linear(self.image_dim, self.model_dim)

        self.feed_forward_1 = FeedForwardBlock(dim=self.model_dim, hidden_dim=self.ffn_dim)
        self.feed_forward_2 = FeedForwardBlock(dim=self.model_dim, hidden_dim=self.ffn_dim)
        self.feed_forward_3 = FeedForwardBlock(dim=self.model_dim, hidden_dim=self.ffn_dim)

    def forward(self, image_feature):
        # image_feature, (1,C1) / (1,image_dim)
        
        # Adopt the linear projection layer at first
        image_feature = self.linear_0(image_feature) # image_feature, (1, model_dim)

        # Cascade 3 projection blocks 
        image_feature = self.feed_forward_1(image_feature)
        image_feature = self.feed_forward_2(image_feature)
        transformed_image_feature = self.feed_forward_3(image_feature)
        return transformed_image_feature 

RMSNorm的原理及与Layer Norm的对比

计算过程,对于输入向量 x ∈ R m x∈R^m xRm

  1. 首先计算输入向量与权重矩阵的加权和,
    在这里插入图片描述
  2. 标准化 Normalization

LayerNorm的计算方法,
在这里插入图片描述

RMSNorm的计算方法,
在这里插入图片描述
故RMSNorm完整减少了计算加权和平均值μ的步骤,保证模型与输入向量和权重解耦、训练过程中梯度稳定及模型收敛速度的前提下,减少了额外的计算开销,加速7%~64%的网络训练(具体的提升指标受硬件、网络结构、其他部分计算开销等影响)。

  1. 加上偏置和激活函数,获得该层的输出
    在这里插入图片描述

Related Word / Prior Work

LLaMA-Adapter

模型输入图像 (image inputs),输出文本(language responses)。

Pipeline:

  1. 使用预训练好的encoder来提取图像特征;
  2. 将图像特征输入LLaMA进行微调。具体的实现方法是将图像特征作为token,拼接到LLaMA输入的word tokens前(LLaVA和MiniGPT-4中也使用同样的concat做法,这样导致数据长度变长、需调用self-attention mechanism,所以会导致额外的计算和训练难度的提示);并且在每一个attention layer前,设置一个初始值为0的、可学习的门参数(zero-initialized gating factor)来调节特征拼接的程度。

局限:只能解决简单的视觉问答(visual question answering scenarios)问题,例如ScienceQA

联系我们

OceanneDLG@outlook.com

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值