InfoNCE 在 CLIP 中的应用原理(代码实现):为什么用交叉熵F.cross_entropy来实现?

InfoNCE 在 CLIP 中的应用原理

CLIP(Contrastive Language-Image Pretraining)是由 OpenAI 提出的一种跨模态预训练模型,旨在通过对比学习对齐图像和文本的表示。它的核心思想是利用 InfoNCE Loss 让匹配的图像-文本对(正样本)的相似度得分高于不匹配的图像-文本对(负样本),从而学习到强大的双模态表示。

CLIP 的工作机制

  1. 输入

    • 一个批次包含 ( N N N ) 个图像 ( I 1 , I 2 , … , I N I_1, I_2, \ldots, I_N I1,I2,,IN ) 和对应的 ( N N N ) 个文本描述 ( T 1 , T 2 , … , T N T_1, T_2, \ldots, T_N T1,T2,,TN )。
    • 这些图像-文本对是正样本对,例如 ( ( I i , T i ) (I_i, T_i) (Ii,Ti) ) 是匹配的。
  2. 嵌入表示

    • 图像通过视觉编码器(通常是 Vision Transformer 或 ResNet)转换为嵌入向量 ( z I = [ z I 1 , z I 2 , … , z I N ] z_I = [z_{I_1}, z_{I_2}, \ldots, z_{I_N}] zI=[zI1,zI2,,zIN] ),维度为 ( N × d N \times d N×d )。
    • 文本通过文本编码器(通常是 Transformer)转换为嵌入向量 ( z T = [ z T 1 , z T 2 , … , z T N ] z_T = [z_{T_1}, z_{T_2}, \ldots, z_{T_N}] zT=[zT1,zT2,,zTN] ),维度也为 ( N × d N \times d N×d )。
  3. 相似度计算

    • 计算图像和文本嵌入之间的相似度矩阵 ( S S S ),通常使用余弦相似度:
      S i , j = z I i ⋅ z T j ∥ z I i ∥ ∥ z T j ∥ S_{i,j} = \frac{z_{I_i} \cdot z_{T_j}}{\|z_{I_i}\| \|z_{T_j}\|} Si,j=zIi∥∥zTjzIizTj
    • ( S S S ) 是一个 ( N × N N \times N N×N ) 的矩阵,其中对角线元素 ( S i , i S_{i,i} Si,i ) 表示正样本对的相似度,其他元素 ( S i , j ( i ≠ j ) S_{i,j} (i \neq j) Si,j(i=j) ) 表示负样本对的相似度。
  4. InfoNCE Loss 的应用
    请参考笔者的另一篇博客:
    深入解析 InfoNCE Loss:对比学习的基石(是在什么背景下提出来的?)

    • CLIP 在两个方向上应用 InfoNCE Loss:
      • 图像到文本:对于每个图像 ( I i I_i Ii ),目标是让 ( S i , i S_{i,i} Si,i )(匹配的文本 ( T i T_i Ti ))的得分高于其他 ( S i , j ( j ≠ i ) S_{i,j} (j \neq i) Si,j(j=i) )(不匹配的文本)。
      • 文本到图像:对于每个文本 ( T i T_i Ti ),目标是让 ( S i , i S_{i,i} Si,i )(匹配的图像 ( I i I_i Ii ))的得分高于其他 ( S j , i ( j ≠ i ) S_{j,i} (j \neq i) Sj,i(j=i) )(不匹配的图像)。
    • 损失函数的形式为:
      L image = − 1 N ∑ i = 1 N log ⁡ exp ⁡ ( S i , i / τ ) ∑ j = 1 N exp ⁡ ( S i , j / τ ) \mathcal{L}_{\text{image}} = -\frac{1}{N} \sum_{i=1}^N \log \frac{\exp(S_{i,i} / \tau)}{\sum_{j=1}^N \exp(S_{i,j} / \tau)} Limage=N1i=1Nlogj=1Nexp(Si,j/τ)exp(Si,i/τ)
      L text = − 1 N ∑ i = 1 N log ⁡ exp ⁡ ( S i , i / τ ) ∑ j = 1 N exp ⁡ ( S j , i / τ ) \mathcal{L}_{\text{text}} = -\frac{1}{N} \sum_{i=1}^N \log \frac{\exp(S_{i,i} / \tau)}{\sum_{j=1}^N \exp(S_{j,i} / \tau)} Ltext=N1i=1Nlogj=1Nexp(Sj,i/τ)exp(Si,i/τ)
      • ( τ \tau τ ) 是一个可学习的温度参数,用于缩放相似度得分。
      • 总损失是两个方向的平均:
        L = 1 2 ( L image + L text ) \mathcal{L} = \frac{1}{2} (\mathcal{L}_{\text{image}} + \mathcal{L}_{\text{text}}) L=21(Limage+Ltext)
  5. 目标

    • 通过最小化 ( L \mathcal{L} L ),CLIP 学习到的图像和文本表示在潜在空间中对齐,使得匹配对的相似度高,不匹配对的相似度低。

为什么用 InfoNCE?

  • InfoNCE Loss 通过对比正样本和负样本,最大化图像和文本之间的互信息(Mutual Information),从而让模型捕获跨模态的语义关系。
  • 在一个批次中,每个图像和文本的负样本数量为 ( N − 1 N-1 N1 ),随着批次大小增加,模型的区分能力更强。

CLIP 的训练代码和推理代码

以下是用 PyTorch 实现的简化和示例性 CLIP 训练和推理代码。假设我们使用预训练的视觉和文本编码器(如 Vision Transformer 和 BERT),并专注于 InfoNCE Loss 的实现。

训练代码

import torch
import torch.nn as nn
import torch.nn.functional as F

# 假设的图像和文本编码器(可以替换为实际模型)
class ImageEncoder(nn.Module):
    def __init__(self, embed_dim=512):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(64 * 16 * 16, embed_dim)
        )
    
    def forward(self, x):
        return self.encoder(x)

class TextEncoder(nn.Module):
    def __init__(self, embed_dim=512):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Embedding(10000, 128),  # 假设词汇表大小为10000
            nn.LSTM(128, 128, batch_first=True),
            nn.Linear(128, embed_dim)
        )
    
    def forward(self, x):
        embeddings = self.encoder[0](x)
        _, (hidden, _) = self.encoder[1](embeddings)
        return self.encoder[2](hidden[-1])

# CLIP 模型
class CLIP(nn.Module):
    def __init__(self, embed_dim=512):
        super().__init__()
        self.image_encoder = ImageEncoder(embed_dim)
        self.text_encoder = TextEncoder(embed_dim)
        self.logit_scale = nn.Parameter(torch.ones([]) * torch.log(torch.tensor(1 / 0.07)))  # 初始温度 0.07
    
    def forward(self, images, texts):
        # 获取嵌入
        image_features = self.image_encoder(images)
        text_features = self.text_encoder(texts)
        
        # L2 归一化
        image_features = F.normalize(image_features, p=2, dim=1)
        text_features = F.normalize(text_features, p=2, dim=1)
        
        # 计算相似度矩阵
        logit_scale = self.logit_scale.exp()
        logits = logit_scale * image_features @ text_features.T  # [N, N]
        
        return logits

# InfoNCE Loss 计算
def info_nce_loss(logits):
    N = logits.shape[0]
    labels = torch.arange(N).to(logits.device)  # 对角线为正样本
    loss_img = F.cross_entropy(logits, labels)  # 图像到文本
    loss_txt = F.cross_entropy(logits.T, labels)  # 文本到图像
    return (loss_img + loss_txt) / 2

# 训练循环
def train_clip():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = CLIP().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    
    # 假设数据:图像 [N, 3, 32, 32],文本 [N, seq_len]
    batch_size = 32
    images = torch.randn(batch_size, 3, 32, 32).to(device)
    texts = torch.randint(0, 10000, (batch_size, 10)).to(device)
    
    for epoch in range(10):
        optimizer.zero_grad()
        logits = model(images, texts)
        loss = info_nce_loss(logits)
        loss.backward()
        optimizer.step()
        print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

if __name__ == "__main__":
    train_clip()

推理代码

import torch
import torch.nn.functional as F

# 使用训练好的 CLIP 模型进行推理
def inference_clip(model, images, texts, device="cuda"):
    model.eval()
    with torch.no_grad():
        # 获取嵌入
        image_features = model.image_encoder(images.to(device))
        text_features = model.text_encoder(texts.to(device))
        
        # L2 归一化
        image_features = F.normalize(image_features, p=2, dim=1)
        text_features = F.normalize(text_features, p=2, dim=1)
        
        # 计算相似度
        logit_scale = model.logit_scale.exp()
        logits = logit_scale * image_features @ text_features.T  # [N_images, N_texts]
        
        # 返回相似度得分或预测
        return logits

# 示例推理
def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = CLIP().to(device)
    
    # 假设测试数据
    test_images = torch.randn(2, 3, 32, 32).to(device)  # 2 张图像
    test_texts = torch.randint(0, 10000, (3, 10)).to(device)  # 3 个文本描述
    
    logits = inference_clip(model, test_images, test_texts)
    print("Similarity scores:\n", logits.cpu().numpy())
    
    # 零样本分类:选择最匹配的文本
    predictions = logits.argmax(dim=1)
    print("Predicted text indices:", predictions.cpu().numpy())

if __name__ == "__main__":
    main()

代码说明

训练代码

  1. 模型结构
    • ImageEncoderTextEncoder 是简化的占位符,实际中可以替换为 Vision Transformer 和 BERT。
    • CLIP 类包含两个编码器和一个可学习的温度参数 ( τ \tau τ )(通过 logit_scale 表示)。
  2. InfoNCE Loss
    • 使用 F.cross_entropy 实现,因为它等价于 InfoNCE Loss 的对数 softmax 形式。(下文有详细解释)
    • 对角线元素是正样本,损失在图像到文本和文本到图像两个方向上对称计算。
  3. 训练循环
    • 随机生成数据进行演示,实际中需要真实图像-文本对数据集(如 LAION)。

推理代码

  1. 推理过程
    • 输入测试图像和文本,计算它们的嵌入和相似度矩阵。
    • 返回相似度得分,可用于零样本分类(选择得分最高的文本)。
  2. 零样本分类
    • 通过 argmax 找到每个图像最匹配的文本索引,展示 CLIP 的分类能力。

总结

CLIP 通过 InfoNCE Loss 在批次内构造正负样本对,训练图像和文本编码器对齐表示。训练代码展示了如何优化这一损失,而推理代码展示了如何利用训练好的模型进行跨模态任务。这种方法简单而高效,证明了 InfoNCE Loss 在跨模态学习中的普适性和强大能力。

使用cross_entropy实现解释

详细解释为什么在 CLIP 的实现中,使用 F.cross_entropy 等价于 InfoNCE Loss 的对数 softmax 形式,以及 labels = torch.arange(N) 的作用和意义。



为什么 F.cross_entropy 等价于 InfoNCE Loss?

PyTorch 的 F.cross_entropy 函数本质上是交叉熵损失的实现,它结合了 softmax 和负对数似然(Negative Log-Likelihood, NLL)计算。我们来一步步分析为什么它能等价于 InfoNCE Loss 的对数 softmax 形式。

1. 交叉熵损失的定义

对于一个分类任务,假设有 ( N N N ) 个类别,输入是 logits ( z = [ z 1 , z 2 , … , z N ] z = [z_1, z_2, \ldots, z_N] z=[z1,z2,,zN] )(未归一化的得分),真实标签是 ( y y y )(类别索引),交叉熵损失为:

L = − log ⁡ ( exp ⁡ ( z y ) ∑ j = 1 N exp ⁡ ( z j ) ) L = -\log \left( \frac{\exp(z_y)}{\sum_{j=1}^N \exp(z_j)} \right) L=log(j=1Nexp(zj)exp(zy))

  • 其中 ( exp ⁡ ( z y ) ∑ j = 1 N exp ⁡ ( z j ) \frac{\exp(z_y)}{\sum_{j=1}^N \exp(z_j)} j=1Nexp(zj)exp(zy) ) 是 softmax 后的概率。
  • ( z y z_y zy ) 是正确类别的 logit。

F.cross_entropy 会自动对输入的 logits 应用 softmax,然后计算指定标签的负对数概率。

2. InfoNCE Loss 的形式

以图像到文本方向为例:

L image = − 1 N ∑ i = 1 N log ⁡ exp ⁡ ( S i , i / τ ) ∑ j = 1 N exp ⁡ ( S i , j / τ ) \mathcal{L}_{\text{image}} = -\frac{1}{N} \sum_{i=1}^N \log \frac{\exp(S_{i,i} / \tau)}{\sum_{j=1}^N \exp(S_{i,j} / \tau)} Limage=N1i=1Nlogj=1Nexp(Si,j/τ)exp(Si,i/τ)

  • 对于每个图像 ( I i I_i Ii ):
    • ( S i , i / τ S_{i,i} / \tau Si,i/τ ) 是正样本的 logit。
    • ( [ S i , 1 / τ , S i , 2 / τ , … , S i , N / τ ] [S_{i,1} / \tau, S_{i,2} / \tau, \ldots, S_{i,N} / \tau] [Si,1/τ,Si,2/τ,,Si,N/τ] ) 是所有候选文本的 logits。
    • 我们希望模型预测 ( S i , i / τ S_{i,i} / \tau Si,i/τ ) 是“正确类别”的 logit。

这与交叉熵的形式完全一致:

  • ( S i , j / τ S_{i,j} / \tau Si,j/τ ) 对应于 logits ( z j z_j zj )。
  • 正样本索引 ( i i i ) 对应于正确标签 ( y y y )。
  • ( exp ⁡ ( S i , i / τ ) ∑ j = 1 N exp ⁡ ( S i , j / τ ) \frac{\exp(S_{i,i} / \tau)}{\sum_{j=1}^N \exp(S_{i,j} / \tau)} j=1Nexp(Si,j/τ)exp(Si,i/τ) ) 是 softmax 概率。

因此,对于每一行 ( S [ i , : ] S[i, :] S[i,:] )(图像 ( I i I_i Ii) 对所有文本的相似度),我们可以将其视为一个分类问题,目标是让模型正确预测第 ( i i i ) 个文本(正样本)。

3. 在 CLIP 中的实现

在代码中:

loss_img = F.cross_entropy(logits, labels)
  • logits 是一个 ( N × N N \times N N×N ) 的矩阵,每行 ( l o g i t s [ i , : ] logits[i, :] logits[i,:] ) 是图像 ( I i I_i Ii ) 对所有文本的相似度得分(即 ( S i , j / τ S_{i,j} / \tau Si,j/τ ))。
  • labels[0, 1, 2, ..., N-1],表示每个图像的正确文本索引(对角线元素)。
  • F.cross_entropy 对每行应用 softmax,然后计算负对数概率,平均后正好是 ( L image \mathcal{L}_{\text{image}} Limage )。

类似地:

loss_txt = F.cross_entropy(logits.T, labels)
  • logits.T 是 ( N × N N \times N N×N ) 的转置矩阵,每行 ( l o g i t s . T [ i , : ] logits.T[i, :] logits.T[i,:] ) 是文本 ( T i T_i Ti ) 对所有图像的相似度得分。
  • 计算过程与图像到文本方向对称。

4. 等价性总结

  • InfoNCE Loss 的对数 softmax 形式是手动计算 softmax 概率并取负对数。
  • F.cross_entropy 内置了这一过程,直接接受未归一化的 logits 和目标标签,计算结果与 InfoNCE Loss 的数学形式一致。
  • 两个方向的损失分别对应于图像到文本和文本到图像的对比任务,平均后实现对称优化。

labels = torch.arange(N) 的作用

labels = torch.arange(N) 是用来生成目标标签的,具体作用如下:

1. 生成正样本的索引

  • 在 CLIP 中,批次内的图像和文本是成对匹配的:
    • 图像 ( I 0 I_0 I0 ) 匹配文本 ( T 0 T_0 T0 )(索引 0)。
    • 图像 ( I 1 I_1 I1 ) 匹配文本 ( T 1 T_1 T1 )(索引 1)。
    • 以此类推,直到 ( I N − 1 I_{N-1} IN1 ) 匹配 ( T N − 1 T_{N-1} TN1 )(索引 ( N − 1 N-1 N1 ))。
  • torch.arange(N) 生成一个从 0 到 ( N − 1 N-1 N1 ) 的序列:[0, 1, 2, ..., N-1]
  • 这个序列表示相似度矩阵 ( S S S ) 中每一行的“正确类别”索引,也就是对角线元素 ( S i , i S_{i,i} Si,i ) 的位置。

2. 与 F.cross_entropy 的配合

  • F.cross_entropy 需要两个输入:
    • logits:预测的得分矩阵(未归一化的 logits)。
    • labels:每一行的正确类别索引。
  • 对于 logits[i, :](图像 ( I i I_i Ii ) 的得分),labels[i] = i 表示正确类别是第 ( i i i ) 个文本(即 ( T i T_i Ti ))。
  • 例如:
    • 如果 ( N = 3 N = 3 N=3 ),labels = [0, 1, 2]
    • 对于 logits[0, :],正确类别是 0(( S 0 , 0 S_{0,0} S0,0 ))。
    • 对于 logits[1, :],正确类别是 1(( S 1 , 1 S_{1,1} S1,1 ))。

3. 为什么用对角线?

  • 在 CLIP 的相似度矩阵 ( S S S ) 中,对角线元素 ( S i , i S_{i,i} Si,i ) 对应正样本对(匹配的图像-文本对)。
  • 非对角线元素 ( S i , j ( i ≠ j ) S_{i,j} (i \neq j) Si,j(i=j) ) 是负样本对(不匹配的图像-文本对)。
  • labels = torch.arange(N) 确保模型的目标是最大化对角线元素的 softmax 概率,与 InfoNCE Loss 的目标一致。

4. 代码中的具体意义

labels = torch.arange(N).to(logits.device)
  • 生成 [0, 1, ..., N-1] 并移动到与 logits 相同的设备(例如 GPU)。
  • 这告诉 F.cross_entropy,对于第 ( i i i ) 行,正确类别是第 ( i i i ) 列,从而实现 InfoNCE Loss 的正样本优化。

完整解释与示例

假设 ( N = 3 N = 3 N=3 ),相似度矩阵 ( S S S ) 为:

S = [[2.0, 0.5, 0.1],  # I_0 对 T_0, T_1, T_2
     [0.3, 1.8, 0.4],  # I_1 对 T_0, T_1, T_2
     [0.2, 0.6, 1.5]]  # I_2 对 T_0, T_1, T_2
  • labels = [0, 1, 2]
  • 图像到文本
    • 对于第一行 [2.0, 0.5, 0.1],正确类别是 0,F.cross_entropy 计算:
      − log ⁡ ( exp ⁡ ( 2.0 ) exp ⁡ ( 2.0 ) + exp ⁡ ( 0.5 ) + exp ⁡ ( 0.1 ) ) -\log \left( \frac{\exp(2.0)}{\exp(2.0) + \exp(0.5) + \exp(0.1)} \right) log(exp(2.0)+exp(0.5)+exp(0.1)exp(2.0))
    • 类似地处理其他行,平均后得到 ( L image \mathcal{L}_{\text{image}} Limage )。
  • 文本到图像
    • 对 ( S . T S.T S.T ) 重复相同过程。

这与手动计算 InfoNCE Loss 的结果一致,只是 F.cross_entropy 更高效。


总结

  1. 为什么等价
    • F.cross_entropy 内置了 softmax 和负对数计算,与 InfoNCE Loss 的对数 softmax 形式数学上等价。
    • 它将对比学习任务转化为分类问题,每一行是一个多分类任务,正样本是正确类别。
  2. labels = torch.arange(N) 的作用
    • 生成对角线元素的索引,表示正样本的位置。
    • 配合 F.cross_entropy,确保损失函数优化的是正样本的 softmax 概率。

这种实现方式既简洁又高效,完美体现了 InfoNCE Loss 在 CLIP 中的应用逻辑。

后记

2025年3月29日20点32分于上海,在grok 3大模型辅助下完成。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值