1. Noise Contrastive Estimation (NCE)
When training a language model on a large-scale dataset, computing the probability distribution at the output layer (e.g. a softmax layer) is expensive, because it requires normalizing over the entire vocabulary. NCE recasts the problem as binary classification, avoiding normalization over the whole vocabulary and thereby greatly reducing the computational cost.
The softmax function is defined as:
$$
\text{Softmax}(x_i) = \frac{e^{x_i}}{\sum_{j=1}^{n}e^{x_j}}
$$
As the number of classes grows, so do the exponentials and the normalization sum that must be computed for every prediction.
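The cost difference can be made concrete with a small, purely illustrative sketch (sizes and indices below are made up): a full softmax touches every entry of the vocabulary, while NCE only ever scores the observed sample plus $k$ noise samples.

```python
import torch

V, k = 50_000, 20                      # made-up sizes: vocabulary size vs. number of noise samples
logits = torch.randn(V)

# Full softmax: one exponential per vocabulary entry plus a sum over all V of them.
probs = torch.softmax(logits, dim=0)

# NCE only evaluates the score of the observed word and k sampled noise words,
# so each training example touches k + 1 entries instead of V.
target = torch.tensor([123])                   # hypothetical index of the observed word
noise = torch.randint(V, (k,))                 # k indices drawn from the noise distribution
scores = logits[torch.cat([target, noise])]    # the (k + 1) quantities the NCE loss needs
```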
The NCE loss is defined as:
$$
\text{NCELoss} = -\frac{1}{N}\sum_{i=1}^{N}\left[\log \frac{P_{\text{model}}(x_i)}{P_{\text{model}}(x_i) + kP_n(x_i)} + \sum_{j=1}^{k} \log \frac{kP_n(x_{ij})}{P_{\text{model}}(x_{ij}) + kP_n(x_{ij})}\right]
$$
- $P_{\text{model}}(x_i)$ is the probability produced by the model. It enters a binary-classification probability: the probability that the positive sample $x_i$ comes from the true data distribution. In NCE we want this probability to be as large as possible, because the model should identify positive samples accurately.
- $P_n(x)$ is the probability that sample $x$ comes from the noise distribution. The simplest choice is a uniform distribution; ideally the noise distribution should be close to the true data distribution.
- $x_{ij}$ is the $j$-th sample drawn from the noise distribution for the $i$-th data point.
- $k$ is the number of noise samples.
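The PyTorch sketch below implements this loss under two assumptions: the input `x` already holds the model's (unnormalized) probabilities, with the positive sample in column 0 and the K noise samples in the remaining columns, and the noise distribution $P_n$ is uniform over `nLem` items.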
```python
import torch
from torch import nn

eps = 1e-7

class NCECriterion(nn.Module):
    def __init__(self, nLem):
        super(NCECriterion, self).__init__()
        self.nLem = nLem  # size of the noise support (e.g. vocabulary size)

    def forward(self, x, targets):
        # x shape: [batchSize, K+1]
        # targets shape: [batchSize]
        # K is the number of noise samples
        batchSize = x.size(0)
        K = x.size(1) - 1
        Pnt = 1 / float(self.nLem)  # P_n of the target under a uniform noise distribution
        Pns = 1 / float(self.nLem)  # P_n of a noise sample under a uniform noise distribution

        # eq 5.1 : P(origin=model) = Pmt / (Pmt + K*Pnt)
        Pmt = x.select(1, 0)              # column 0 holds the model output for the positive sample
        Pmt_div = Pmt.add(K * Pnt + eps)
        lnPmt = torch.div(Pmt, Pmt_div)

        # eq 5.2 : P(origin=noise) = K*Pns / (Pms + K*Pns)
        Pon_div = x.narrow(1, 1, K).add(K * Pns + eps)  # columns 1..K hold the noise samples
        Pon = Pon_div.clone().fill_(K * Pns)
        lnPon = torch.div(Pon, Pon_div)

        # equation 6 in ref. A
        lnPmt.log_()
        lnPon.log_()
        lnPmtsum = lnPmt.sum(0)
        lnPonsum = lnPon.view(-1, 1).sum(0)
        loss = -(lnPmtsum + lnPonsum) / batchSize
        return loss
```
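A minimal usage sketch with made-up sizes (the `targets` argument is accepted for API compatibility but not used by the criterion above):

```python
# Hypothetical example: batch of 8 samples, K = 4 noise samples each,
# noise distribution uniform over nLem = 10000 items.
criterion = NCECriterion(nLem=10000)
x = torch.rand(8, 5).clamp_min(1e-3)        # column 0: P_model of the positive; columns 1..4: noise
targets = torch.zeros(8, dtype=torch.long)  # unused by this criterion
loss = criterion(x, targets)
print(loss)
```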
2. Information Noise-Contrastive Estimation (InfoNCE)
The InfoNCE loss is a contrastive loss commonly used in self-supervised learning, particularly within the contrastive learning framework. It pulls positive (similar) samples closer together and pushes negative (dissimilar) samples apart, so the model learns useful representations of the data.
It is defined as:
$$
\mathcal{L}_{\text{InfoNCE}} = -\mathbb{E}\left[\log \frac{\exp(\text{sim}(x, x^+))}{\exp(\text{sim}(x, x^+)) + \sum_{i=1}^{K} \exp(\text{sim}(x, x^-_i))} \right]
$$
- $\text{sim}(x, y)$ is the similarity between two samples, usually a dot product or cosine similarity (in practice scaled by a temperature, as in the implementation below).
InfoNCE encourages the model to distinguish positive from negative samples, which yields better data representations.
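In the implementation below, queries, positive keys, and optional negative keys are L2-normalized, their dot products (i.e. cosine similarities) become logits, and the loss is an ordinary cross-entropy in which the positive pair plays the role of the correct class.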
```python
import torch
import torch.nn.functional as F
from torch import nn

class InfoNCE(nn.Module):
    def __init__(self, temperature=0.1, reduction='mean', negative_mode='unpaired'):
        super().__init__()
        self.temperature = temperature
        self.reduction = reduction
        self.negative_mode = negative_mode

    def forward(self, query, positive_key, negative_keys=None):
        return info_nce(query, positive_key, negative_keys,
                        temperature=self.temperature,
                        reduction=self.reduction,
                        negative_mode=self.negative_mode)

def transpose(x):
    return x.transpose(-2, -1)

def normalize(*xs):
    # F.normalize divides each vector by its norm, so all embeddings become unit-length.
    return [None if x is None else F.normalize(x, dim=-1) for x in xs]

def info_nce(query, positive_key,
             negative_keys=None, temperature=0.1,
             reduction='mean', negative_mode='unpaired'):
    """
    If negative_mode == 'paired', negative_keys is a (N, M, D) tensor: each query has its own set of negative keys.
    If negative_mode == 'unpaired', negative_keys is a (M, D) tensor: all queries share the same set of negative keys.
    """
    if query.dim() != 2:
        raise ValueError('query must be 2D tensor')
    if positive_key.dim() != 2:
        raise ValueError('positive_key must be 2D tensor')
    if negative_keys is not None:
        if negative_mode == 'unpaired' and negative_keys.dim() != 2:
            raise ValueError('negative_keys must be 2D tensor for negative_mode=unpaired')
        if negative_mode == 'paired' and negative_keys.dim() != 3:
            raise ValueError('negative_keys must be 3D tensor for negative_mode=paired')

    # Check matching number of samples.
    if len(query) != len(positive_key):
        raise ValueError('<query> and <positive_key> must have the same number of samples.')
    if negative_keys is not None:
        if negative_mode == 'paired' and len(query) != len(negative_keys):
            raise ValueError("If negative_mode == 'paired', then <negative_keys> must have the same number of samples as <query>.")

    # Embedding vectors should have the same number of components.
    if query.shape[-1] != positive_key.shape[-1]:
        raise ValueError('Vectors of <query> and <positive_key> should have the same number of components.')
    if negative_keys is not None:
        if query.shape[-1] != negative_keys.shape[-1]:
            raise ValueError('Vectors of <query> and <negative_keys> should have the same number of components.')

    query, positive_key, negative_keys = normalize(query, positive_key, negative_keys)
    if negative_keys is not None:
        # Explicit negative keys.
        # Cosine similarity between positive pairs.
        positive_logit = torch.sum(query * positive_key, dim=1, keepdim=True)  # (N, 1)
        if negative_mode == 'unpaired':
            # Cosine similarity between all query-negative combinations.
            negative_logits = query @ transpose(negative_keys)                 # (N, M)
        elif negative_mode == 'paired':
            query = query.unsqueeze(1)                                         # (N, 1, D)
            negative_logits = query @ transpose(negative_keys)                 # (N, 1, M)
            negative_logits = negative_logits.squeeze(1)                       # (N, M)
        logits = torch.cat([positive_logit, negative_logits], dim=1)           # (N, 1+M)
        labels = torch.zeros(len(logits), dtype=torch.long, device=query.device)  # (N,): the positive is always class 0
    else:
        # Negative keys are implicitly the off-diagonal positive keys.
        # Cosine similarity between all combinations.
        logits = query @ transpose(positive_key)  # (N, N)
        # Positive keys are the entries on the diagonal.
        labels = torch.arange(len(query), device=query.device)
    return F.cross_entropy(logits / temperature, labels, reduction=reduction)
```
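A quick usage sketch with made-up shapes:

```python
# Hypothetical sizes: 32 queries, one positive per query, 48 negatives shared by all queries.
loss_fn = InfoNCE(temperature=0.07, negative_mode='unpaired')
query = torch.randn(32, 128)
positive_key = torch.randn(32, 128)
negative_keys = torch.randn(48, 128)
loss = loss_fn(query, positive_key, negative_keys)

# With no explicit negatives, the other positives in the batch act as negatives (in-batch negatives).
loss_in_batch = loss_fn(torch.randn(32, 128), torch.randn(32, 128))
```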
3. Differences
- NCE is typically used to train language models and other probabilistic models, especially when the vocabulary is very large, e.g. for learning word embeddings in natural language processing.
- NCE's main goal is to learn the parameters of a probabilistic model by turning density estimation into a classification problem: the model is trained to distinguish samples drawn from the data distribution from samples drawn from a known noise distribution.
- InfoNCE is mainly used in self-supervised learning tasks such as feature learning, image and text representation learning, and, more recently, multimodal learning.
- InfoNCE is designed with representation learning in mind, particularly within self-supervised and contrastive learning frameworks. By maximizing the mutual information between an anchor and its positive sample while minimizing the information shared between the anchor and negative samples, it learns effective data representations.