simcse的lose函数理解

https://github.com/bojone/SimCSE

源代码如下:

def simcse_loss(y_true, y_pred):
    """用于SimCSE训练的loss
    """
    # 构造标签
    idxs = K.arange(0, K.shape(y_pred)[0])
    idxs_1 = idxs[None, :]
    idxs_2 = (idxs + 1 - idxs % 2 * 2)[:, None]
    y_true = K.equal(idxs_1, idxs_2)
    y_true = K.cast(y_true, K.floatx())
    # 计算相似度
    y_pred = K.l2_normalize(y_pred, axis=1)
    similarities = K.dot(y_pred, K.transpose(y_pred))
    similarities = similarities - tf.eye(K.shape(y_pred)[0]) * 1e12
    similarities = similarities * 20
    loss = K.categorical_crossentropy(y_true, similarities, from_logits=True)
    return K.mean(loss)

刚看可能觉得蒙圈,那是因为这块代码的逻辑要和数据加载的部分结合起来看:

class data_generator(DataGenerator):
    """训练语料生成器
    """
    def __iter__(self, random=False):
        batch_token_ids = []
        for is_end, token_ids in self.sample(random):
            batch_token_ids.append(token_ids)
            batch_token_ids.append(token_ids)
            if len(batch_token_ids) == self.batch_size * 2 or is_end:
                batch_token_ids = sequence_padding(batch_token_ids)
                batch_segment_ids = np.zeros_like(batch_token_ids)
                batch_labels = np.zeros_like(batch_token_ids[:, :1])
                yield [batch_token_ids, batch_segment_ids], batch_labels
                batch_token_ids = []

从数据加载部分的代码可以看出,加载的数据是a,a,b,b,c,c,d,d,...

所以,假设根据除了句子自己其他都是负样本的思想,可以得出,y_true应该为:

\begin{equation} \begin{bmatrix} A & 0 & 0 & ... & 0\\ 0 & A & 0 & ... & 0\\ ... & ... &... & ...&0\\ 0 & 0 & 0 & ... & A \end{bmatrix} \end{equation},其中A=\begin{equation} \begin{bmatrix} 0 & 1\\ 1 & 0 \end{bmatrix} \end{equation}

所以,

idxs = K.arange(0, K.shape(y_pred)[0])
idxs_1 = idxs[None, :]
idxs_2 = (idxs + 1 - idxs % 2 * 2)[:, None]
y_true = K.equal(idxs_1, idxs_2)
y_true = K.cast(y_true, K.floatx())

#以上代码等价于
idxs = K.arange(0, K.shape(y_pred)[0])
idx1 = idxs[None, :]
idx2 = idxs[:,None]
#y_true=tf.cast(tf.equal(tf.abs(tf.subtract(idx1,idx2)),1),dtype=tf.int32)
y_true=np.equal(np.abs(np.subtract(idx1,idx2)),1).astype(np.int32)

 标签构造理解了,接下来,对于相似性的计算就原理如下:

数据都和自己相似,由于bert模型本身增加了遮罩处理和dropout处理,所以即使同一个句子两次输入产生的向量也会有所差别,利用这个原理,将同一个向量相似计算(对角线)数据无穷小,以便于模型训练时此数据对模型训练的贡献减少;同一个句子两次数据的向量进行相似性计算,期望二则的更加相似。

对于 similarities = similarities - tf.eye(K.shape(y_pred)[0]) * 1e12代码,可以理解为将对角线元素无穷小。对角线即是和完全一样的子集的相似性比较,使其无穷小也就是忽略这块数据对模型优化的影响;

对于similarities = similarities * 20,这块是在忽略对角数据后的操作,表示将相似性方大20倍,这样可以使得模型更快的收敛。

SimCSE loss理解 - 知乎

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

会发paper的学渣

您的鼓励和将是我前进的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值