引言
思考了半天苏神的SimCSE(https://github.com/bojone/SimCSE/blob/main/eval.py),遂记录在此,帮助有需要的同学理解。
原文无监督方法
在原文的a图中,我们可以知道在一个batch中,输入的n个文本都是不相关的,而正样本对是通过对相同样本dropout2次得到的,每一次使用不同的dropout mask。
正负样本对构造方法
我们可以发现,在生成样本数据的时候,每一个相同的样本连续生成了2次,由于每个样本使用的dropout mask不一样,因此在一个minibatch中,连续的两个样本形成正样本对,非连续的样本对形成负样本对
SimCSE损失
def simcse_loss(y_true, y_pred):
"""用于SimCSE训练的loss
"""
# 构造标签
idxs = K.arange(0, K.shape(y_pred)[0])
idxs_1 = idxs[None, :]
idxs_2 = (idxs + 1 - idxs % 2 * 2)[:, None]
y_true = K.equal(idxs_1, idxs_2)
y_true = K.cast(y_true, K.floatx())
# 计算相似度
y_pred = K.l2_normalize(y_pred, axis=1)
similarities = K.dot(y_pred, K.transpose(y_pred))
similarities = similarities - tf.eye(K.shape(y_pred)[0]) * 1e12
similarities = similarities * 20
loss = K.categorical_crossentropy(y_true, similarities, from_logits=True)
return K.mean(loss)
核心的损失代码如上所示,上述代码想要构造出如下的正确标签:
我们可以通过构建一个0到n的数组,并将两两前后反转的方式,得到上述正确标签:
这是上述代码y_true的构造方式,由于我们不需要自己与自己相似度,即对角线上的值,因此采用将logits置为负无穷,使得其指数为0,对损失无影响,similarities = similarities - tf.eye(K.shape(y_pred)[0]) * 1e12
。
对比损失
对比损失的核心公式如图所示,其本质上是一个交叉熵损失:
C
E
=
−
l
o
g
e
t
∑
i
e
i
CE = -log\frac{e^t}{\sum_ie^i}
CE=−log∑ieiet
因此可以在计算好相似度除以温度系数之后,采用交叉熵损失的方式来进行计算