基于 Python 的自然语言处理系列文章 (80):CoCa 模型原理与实现

一、CoCa 模型简介

        CoCa(Contrastive Captioner)是 Google 提出的一种图文联合学习模型,核心思想是统一 图文对比学习图像字幕生成任务,实现一种通用的图文基础模型(Image-Text Foundation Model)。

        CoCa 由以下三部分构成:

  • 视觉编码器(Vision Transformer):提取图像 patch 表征;

  • 语言建模器(Transformer):对文本进行建模与解码;

  • 双任务目标

    • 文本生成任务(Captioning loss):通过自回归的方式生成描述;

    • 对比学习任务(Contrastive loss):类似 CLIP,通过最大化图文对的匹配相似度进行表示学习。

        它的训练目标是将图文理解与生成统一在一个模型中,同时优化两个目标函数,提升模型在下游多模态任务上的泛化能力。

二、环境配置与依赖导入

import torch
from torch import einsum, nn
import torch.nn.functional as F
from einops import rearrange, repeat

三、CoCa 模型各模块详解

1. LayerNorm 与 Residual

        为确保数值稳定性,CoCa 自定义了不含 bias 的 LayerNorm:

class LayerNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        self.register_buffer("beta", torch.zeros(dim))

    def forward(self, x):
        return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)

        同时实现了 Residual 模块,用于残差连接:

class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        return self.fn(x, *args, **kwargs) + x

2. Rotary Positional Embedding(旋转位置编码)

        用于 Transformer 模块中的位置编码:

class RotaryEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, max_seq_len, *, device):
        seq = torch.arange(max_seq_len, device=device, dtype=self.inv_freq.dtype)
        freqs = einsum("i , j -> i j", seq, self.inv_freq)
        return torch.cat((freqs, freqs), dim=-1)

3. Transformer Block

        使用了并行注意力机制:

class ParallelTransformerBlock(nn.Module):
    def __init__(self, dim, dim_head=64, heads=8, ff_mult=4):
        ...
        self.rotary_emb = RotaryEmbedding(dim_head)
        self.fused_attn_ff_proj = nn.Linear(dim, sum(self.fused_dims), bias=False)
        ...
    
    def forward(self, x, attn_mask=None):
        ...
        q, k, v, ff = self.fused_attn_ff_proj(x).split(self.fused_dims, dim=-1)
        ...
        q = apply_rotary_pos_emb(self.get_rotary_embedding(n, device), q)
        ...
        out = einsum("b h i j, b j d -> b h i d", attn, v)
        return self.attn_out(rearrange(out, "b h n d -> b n (h d)")) + self.ff_out(ff)

        搭配旋转操作 apply_rotary_pos_embrotate_half,实现绝对位置嵌入替代方案。        

4. 图文交叉注意力层:CrossAttention

class CrossAttention(nn.Module):
    def __init__(self, dim, context_dim=None, dim_head=64, heads=8, ...):
        ...
        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(context_dim, dim_head * 2, bias=False)

    def forward(self, x, context):
        ...
        sim = einsum('b h i d, b j d -> b h i j', q, k)
        ...
        out = einsum('b h i j, b j d -> b h i d', attn, v)
        ...
        return self.to_out(out) + self.ff(x) if self.ff else self.to_out(out)

        用于图文信息融合,文本 attends to 图像 patch tokens。

四、主模型结构:CoCa 类实现

class CoCa(nn.Module):
    def __init__(..., caption_loss_weight=1.0, contrastive_loss_weight=1.0, ...):
        ...
        self.token_emb = nn.Embedding(num_tokens, dim)
        self.text_cls_token = nn.Parameter(torch.randn(dim))
        self.img_attn_pool = CrossAttention(...)
        ...
        self.unimodal_layers = nn.ModuleList([...])
        self.multimodal_layers = nn.ModuleList([...])
        ...
        self.to_logits = nn.Sequential(
            LayerNorm(dim),
            nn.Linear(dim, num_tokens, bias=False)
        )
        ...

        forward 流程:

  1. embed_text:文本序列加 CLS token → unimodal transformer → text embedding

  2. embed_image:ViT 编码图像 → attention pool → 图像 CLS / patch tokens

  3. forward:multimodal 层:文本 attends 图像、计算 logits(caption)和 latent(contrastive)、最终返回 loss / logits / embedding

五、完整训练示例

1. 加载图像编码器(ViT)

from vit_pytorch import ViT
from vit_pytorch.extractor import Extractor

vit = ViT(
    image_size=256,
    patch_size=32,
    num_classes=1000,
    dim=1024,
    depth=6,
    heads=16,
    mlp_dim=2048
)
vit = Extractor(vit, return_embeddings_only=True)

2. 构建 CoCa 模型

coca = CoCa(
    dim=512,
    img_encoder=vit,
    image_dim=1024,
    num_tokens=20000,
    unimodal_depth=6,
    multimodal_depth=6,
    dim_head=64,
    heads=8,
    caption_loss_weight=1.0,
    contrastive_loss_weight=1.0
).to(device)

3. 输入数据与训练

text = torch.randint(0, 20000, (4, 512)).to(device)
images = torch.randn(4, 3, 256, 256).to(device)

loss = coca(
    text=text,
    images=images,
    return_loss=True
)
loss.backward()

4. 获取模型输出

  • 获取 caption logits(生成任务)

logits = coca(
    text=text,
    images=images
)  # shape: (batch, seq_len, vocab_size)
  • 获取 text/image embedding(检索任务)

text_embeds, image_embeds = coca(
    text=text,
    images=images,
    return_embeddings=True
)  # shape: (batch, dim)

六、总结

        CoCa 作为一款统一图文基础模型,设计非常 elegant:

  • ✨ 对比目标拉近图文全局表示

  • ✨ Captioning 实现图像生成文字

  • ✨ 结构解耦,便于灵活扩展

  • ✨ 图像输入统一为 patch embedding + attention pooling

        从零样本图文检索到高质量图像描述,CoCa 展现了强大的多模态建模能力。

        📌 下一篇将进入 RLHF(Reinforcement Learning with Human Feedback)相关内容,开启多轮对话训练的范式探索,敬请期待!

如果你觉得这篇博文对你有帮助,请点赞、收藏、关注我,并且可以打赏支持我!

欢迎关注我的后续博文,我将分享更多关于人工智能、自然语言处理和计算机视觉的精彩内容。

谢谢大家的支持!

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

会飞的Anthony

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值