对比损失(Contrastive Loss)与大模型:从原理到实践
在现代深度学习中,对比损失(Contrastive Loss)是一种核心技术,尤其是在对比学习(Contrastive Learning)中被广泛使用。通过最大化正样本之间的相似度、最小化负样本之间的相似度,对比损失有效地增强了表示学习的能力。在大模型时代,例如 CLIP、SimCLR、DINO 等,都依赖对比损失来推动模型性能的提升。
1. 什么是对比损失?
对比损失的定义
对比损失的目标是学习一个嵌入空间,在该空间中:
- 正样本对(positive pairs)(例如同一图像的不同视角,或一张图像和其对应的文本描述)距离尽可能近;
- 负样本对(negative pairs)(例如不相关的图像和文本)距离尽可能远。
公式化描述如下,以批量大小 (
N
N
N ) 的样本为例:
L
contrastive
=
−
1
N
∑
i
=
1
N
log
exp
(
sim
(
z
i
,
z
i
+
)
/
τ
)
∑
j
=
1
N
exp
(
sim
(
z
i
,
z
j
)
/
τ
)
\mathcal{L}_{\text{contrastive}} = -\frac{1}{N}\sum_{i=1}^N \log \frac{\exp(\text{sim}(z_i, z_i^+) / \tau)}{\sum_{j=1}^N \exp(\text{sim}(z_i, z_j) / \tau)}
Lcontrastive=−N1i=1∑Nlog∑j=1Nexp(sim(zi,zj)/τ)exp(sim(zi,zi+)/τ)
- ( z i z_i zi ) 和 ( z i + z_i^+ zi+ ) 是正样本对的嵌入向量。
- ( sim ( ⋅ , ⋅ ) \text{sim}(\cdot, \cdot) sim(⋅,⋅) ) 通常是余弦相似度。
- ( τ \tau τ ) 是温度参数,控制分布的“平滑”程度。
2. 对比损失与 Logits 的关系
在实现对比损失时,Logits 是计算相似度和分布的核心中间变量。
-
Logits 的构成:
- 对于每对样本 (
(
z
i
,
z
j
)
(z_i, z_j)
(zi,zj) ),通过相似度函数(如点积或余弦相似度)得到 logits 值:
logits i j = sim ( z i , z j ) \text{logits}_{ij} = \text{sim}(z_i, z_j) logitsij=sim(zi,zj) - Logits 是未归一化的相似性分数,直接参与 Softmax 分布的计算。
- 对于每对样本 (
(
z
i
,
z
j
)
(z_i, z_j)
(zi,zj) ),通过相似度函数(如点积或余弦相似度)得到 logits 值:
-
Softmax 正规化:
- Logits 被用来计算每个正样本对的条件概率:
P ( i + ∣ i ) = exp ( logits i i + / τ ) ∑ j = 1 N exp ( logits i j / τ ) P(i^+|i) = \frac{\exp(\text{logits}_{ii^+} / \tau)}{\sum_{j=1}^N \exp(\text{logits}_{ij} / \tau)} P(i+∣i)=∑j=1Nexp(logitsij/τ)exp(logitsii+/τ) - 这种概率分布直接反映正样本与其他样本的区分度。
- Logits 被用来计算每个正样本对的条件概率:
3. 对比损失的典型应用:大模型中的案例
(1) CLIP:图文对齐的典范
CLIP(Contrastive Language–Image Pretraining)是 OpenAI 提出的一个多模态模型,通过对比损失学习图像和文本的对齐关系。
- 用 Logits 做什么?
- CLIP 将图像和文本分别编码为嵌入向量 ( z image z_{\text{image}} zimage ) 和 ( z text z_{\text{text}} ztext )。
- 通过点积计算图像和文本之间的相似度,得到一个 (
N
×
N
N \times N
N×N ) 的 logits 矩阵,其中每一行表示一个图像与所有文本的相似度:
logits i j = z image , i ⋅ z text , j ∥ z image , i ∥ ∥ z text , j ∥ \text{logits}_{ij} = \frac{z_{\text{image}, i} \cdot z_{\text{text}, j}}{\|z_{\text{image}, i}\| \|z_{\text{text}, j}\|} logitsij=∥zimage,i∥∥ztext,j∥zimage,i⋅ztext,j - 使用对比损失,最大化正确图文对的概率,同时最小化错误配对的概率。
(2) SimCLR:无监督表征学习
SimCLR 是一种经典的无监督对比学习方法,使用数据增强生成正样本对。
- 用 Logits 做什么?
- SimCLR 将增强后的图像编码为嵌入,计算每对样本之间的余弦相似度作为 logits。
- 对每个样本,计算其增强版本与其他样本的对比分布,用对比损失优化模型。
(3) DINO:自监督视觉表征
DINO(Self-Distillation with No Labels)是一种自监督学习方法,通过对比损失优化不同视角的相似性。
- 用 Logits 做什么?
- DINO 使用教师网络和学生网络生成不同视角的嵌入,并通过 logits 计算两者的相似性分布。
- 使用对比损失对齐教师和学生的分布。
4. 实际例子:从 Logits 到对比损失
场景:CLIP 图文对齐
假设有 2 张图像和 2 条文本,编码后的嵌入如下:
- 图像嵌入:
z image , 1 = [ 1 , 0 ] , z image , 2 = [ 0 , 1 ] z_{\text{image}, 1} = [1, 0], \quad z_{\text{image}, 2} = [0, 1] zimage,1=[1,0],zimage,2=[0,1] - 文本嵌入:
z text , 1 = [ 1 , 0 ] , z text , 2 = [ 0 , 1 ] z_{\text{text}, 1} = [1, 0], \quad z_{\text{text}, 2} = [0, 1] ztext,1=[1,0],ztext,2=[0,1]
Step 1: 计算 Logits
通过点积计算 logits:
logits
=
[
1
0
0
1
]
\text{logits} = \begin{bmatrix} 1 & 0 \\ 0 & 1 \end{bmatrix}
logits=[1001]
Step 2: Softmax 归一化
将 logits 归一化为概率分布(温度参数 (
τ
=
1
\tau = 1
τ=1 )):
P
(
image
1
∣
text
1
)
=
exp
(
1
)
exp
(
1
)
+
exp
(
0
)
≈
0.73
P(\text{image}_1|\text{text}_1) = \frac{\exp(1)}{\exp(1) + \exp(0)} \approx 0.73
P(image1∣text1)=exp(1)+exp(0)exp(1)≈0.73
P
(
image
1
∣
text
2
)
=
exp
(
0
)
exp
(
1
)
+
exp
(
0
)
≈
0.27
P(\text{image}_1|\text{text}_2) = \frac{\exp(0)}{\exp(1) + \exp(0)} \approx 0.27
P(image1∣text2)=exp(1)+exp(0)exp(0)≈0.27
Step 3: 计算损失
对每个正确的图文对,计算交叉熵损失:
L
=
−
1
2
(
log
(
0.73
)
+
log
(
0.73
)
)
≈
0.63
\mathcal{L} = -\frac{1}{2}\big(\log(0.73) + \log(0.73)\big) \approx 0.63
L=−21(log(0.73)+log(0.73))≈0.63
5. 洞见与未来方向
洞见
-
Logits 是表示能力的核心:
Logits 直接反映模型对样本间关系的编码能力,其质量决定了对比损失的优化效果。 -
对比学习的鲁棒性:
对比损失在无监督和多模态任务中表现优异,能够有效学习出强大的嵌入表示。
未来方向
随着大模型的发展,对比损失将进一步结合更多模态(例如音频、视频),并通过改进训练策略(如温度参数调节、自适应权重分配等)提升性能。
总结
对比损失是大模型时代的关键技术,从 CLIP 的图文对齐到 SimCLR 的无监督学习,它通过 Logits 精确建模样本之间的相似性和差异性。通过对比损失,我们可以在嵌入空间中实现对正样本的高聚合和负样本的有效区分,为模型的实际应用提供坚实的基础。
Contrastive Loss and Large Models: Understanding Logits and Their Applications
Contrastive loss is a fundamental technique in deep learning, especially in contrastive learning. By maximizing the similarity between positive pairs and minimizing it between negative pairs, contrastive loss enhances a model’s ability to learn meaningful representations. In the era of large models, examples like CLIP, SimCLR, and DINO highlight the widespread use of contrastive loss to boost model performance.
1. What Is Contrastive Loss?
Definition
The goal of contrastive loss is to learn an embedding space where:
- Positive pairs (e.g., different augmentations of the same image, or an image and its corresponding text) are close to each other.
- Negative pairs (e.g., unrelated images and text) are far apart.
For a batch size (
N
N
N ), contrastive loss is defined as:
L
contrastive
=
−
1
N
∑
i
=
1
N
log
exp
(
sim
(
z
i
,
z
i
+
)
/
τ
)
∑
j
=
1
N
exp
(
sim
(
z
i
,
z
j
)
/
τ
)
\mathcal{L}_{\text{contrastive}} = -\frac{1}{N}\sum_{i=1}^N \log \frac{\exp(\text{sim}(z_i, z_i^+) / \tau)}{\sum_{j=1}^N \exp(\text{sim}(z_i, z_j) / \tau)}
Lcontrastive=−N1i=1∑Nlog∑j=1Nexp(sim(zi,zj)/τ)exp(sim(zi,zi+)/τ)
- ( z i z_i zi ) and ( z i + z_i^+ zi+ ) are embeddings of positive pairs.
- ( sim ( ⋅ , ⋅ ) \text{sim}(\cdot, \cdot) sim(⋅,⋅) ) is a similarity function, often cosine similarity.
- ( τ \tau τ ) is a temperature parameter controlling the sharpness of the distribution.
2. The Role of Logits in Contrastive Loss
Logits serve as the core intermediate variable in the computation of contrastive loss.
-
Logits Construction:
- For each sample pair (
(
z
i
,
z
j
)
(z_i, z_j)
(zi,zj) ), logits are derived from a similarity measure:
logits i j = sim ( z i , z j ) \text{logits}_{ij} = \text{sim}(z_i, z_j) logitsij=sim(zi,zj) - These logits represent unnormalized similarity scores.
- For each sample pair (
(
z
i
,
z
j
)
(z_i, z_j)
(zi,zj) ), logits are derived from a similarity measure:
-
Softmax Normalization:
- Logits are transformed into probabilities to compute the likelihood of positive pairs:
P ( i + ∣ i ) = exp ( logits i i + / τ ) ∑ j = 1 N exp ( logits i j / τ ) P(i^+|i) = \frac{\exp(\text{logits}_{ii^+} / \tau)}{\sum_{j=1}^N \exp(\text{logits}_{ij} / \tau)} P(i+∣i)=∑j=1Nexp(logitsij/τ)exp(logitsii+/τ) - This normalization helps the model focus on distinguishing between positive and negative pairs.
- Logits are transformed into probabilities to compute the likelihood of positive pairs:
3. Applications of Contrastive Loss in Large Models
(1) CLIP: Aligning Images and Text
CLIP (Contrastive Language–Image Pretraining) is a multimodal model by OpenAI that uses contrastive loss to align images and text.
- How CLIP Uses Logits:
- CLIP encodes images and text into embeddings, ( z image z_{\text{image}} zimage ) and ( z text z_{\text{text}} ztext ).
- Logits are computed via dot product similarity, producing a (
N
×
N
N \times N
N×N ) logits matrix where each row represents an image-text similarity distribution:
logits i j = z image , i ⋅ z text , j ∥ z image , i ∥ ∥ z text , j ∥ \text{logits}_{ij} = \frac{z_{\text{image}, i} \cdot z_{\text{text}, j}}{\|z_{\text{image}, i}\| \|z_{\text{text}, j}\|} logitsij=∥zimage,i∥∥ztext,j∥zimage,i⋅ztext,j - Contrastive loss maximizes the probability of the correct image-text pair while minimizing others.
(2) SimCLR: Self-Supervised Representation Learning
SimCLR is a self-supervised method that learns representations by maximizing agreement between augmentations of the same image.
- How SimCLR Uses Logits:
- Augmented image embeddings are compared using cosine similarity to compute logits.
- Contrastive loss ensures that embeddings of the same image augmentation are close while embeddings of different images are far apart.
(3) DINO: Self-Distillation for Vision
DINO (Self-Distillation with No Labels) uses contrastive loss to align representations between teacher and student networks.
- How DINO Uses Logits:
- DINO computes logits for embeddings generated by the teacher and student networks.
- Contrastive loss aligns the similarity distributions of the two networks across different augmentations of the same input.
4. Practical Example: From Logits to Contrastive Loss
Scenario: CLIP for Image-Text Alignment
Suppose we have two images and two texts. The embeddings are as follows:
- Image embeddings:
z image , 1 = [ 1 , 0 ] , z image , 2 = [ 0 , 1 ] z_{\text{image}, 1} = [1, 0], \quad z_{\text{image}, 2} = [0, 1] zimage,1=[1,0],zimage,2=[0,1] - Text embeddings:
z text , 1 = [ 1 , 0 ] , z text , 2 = [ 0 , 1 ] z_{\text{text}, 1} = [1, 0], \quad z_{\text{text}, 2} = [0, 1] ztext,1=[1,0],ztext,2=[0,1]
Step 1: Compute Logits
Using dot product similarity:
logits
=
[
1
0
0
1
]
\text{logits} = \begin{bmatrix} 1 & 0 \\ 0 & 1 \end{bmatrix}
logits=[1001]
Step 2: Apply Softmax
Convert logits to probabilities (with temperature (
τ
=
1
\tau = 1
τ=1 )):
P
(
text
1
∣
image
1
)
=
exp
(
1
)
exp
(
1
)
+
exp
(
0
)
≈
0.73
P(\text{text}_1|\text{image}_1) = \frac{\exp(1)}{\exp(1) + \exp(0)} \approx 0.73
P(text1∣image1)=exp(1)+exp(0)exp(1)≈0.73
P
(
text
2
∣
image
1
)
=
exp
(
0
)
exp
(
1
)
+
exp
(
0
)
≈
0.27
P(\text{text}_2|\text{image}_1) = \frac{\exp(0)}{\exp(1) + \exp(0)} \approx 0.27
P(text2∣image1)=exp(1)+exp(0)exp(0)≈0.27
Step 3: Compute Loss
For correct image-text pairs, the loss is:
L
=
−
1
2
(
log
(
0.73
)
+
log
(
0.73
)
)
≈
0.63
\mathcal{L} = -\frac{1}{2}\big(\log(0.73) + \log(0.73)\big) \approx 0.63
L=−21(log(0.73)+log(0.73))≈0.63
5. Insights and Future Directions
Insights
-
Logits Are Critical for Representations:
Logits reflect the raw similarity between embeddings and directly influence the optimization of contrastive loss. -
Robustness of Contrastive Learning:
Contrastive loss excels in both supervised and unsupervised tasks by leveraging relative relationships between samples.
Future Directions
With the expansion of large models, contrastive loss can further integrate multimodal data (e.g., audio, video) and optimize embeddings for diverse tasks. Techniques like adaptive temperature scaling and hybrid contrastive objectives are promising areas for innovation.
Conclusion
Contrastive loss is a cornerstone of modern large-model training. By leveraging logits to compute similarity distributions, models like CLIP, SimCLR, and DINO achieve remarkable alignment and representation learning. The continued evolution of contrastive techniques will likely play a key role in the development of next-generation AI systems.
后记
2024年12月13日21点18分于上海,在GPT4o大模型辅助下完成。