CoSENT的loss部分学习

CoSENT的loss部分学习:
以pytorch版本为例,里面写的很详细,虽然能看懂苏神写的loss原理,可是在刚开始看到loss的实现时看不懂,仔细研读才发现原来是这么回事,不得不佩服大佬们的代码能力太强了,自己是想不到,唉,劝退了一波。。这里作为自己学习的一个记录吧,有错误的地方欢迎指正。

https://github.com/shawroad/CoSENT_Pytorch

def calc_loss(y_true, y_pred):
    # 1. 取出真实的标签
    y_true = y_true[::2]    # tensor([1, 0, 1]) 真实的标签

    # 2. 对输出的句子向量进行l2归一化   后面只需要对应为相乘  就可以得到cos值了
    norms = (y_pred ** 2).sum(axis=1, keepdims=True) ** 0.5
    # y_pred = y_pred / torch.clip(norms, 1e-8, torch.inf)
    y_pred = y_pred / norms

    # 3. 奇偶向量相乘
    y_pred = torch.sum(y_pred[::2] * y_pred[1::2], dim=1) * 20

    # 4. 取出负例-正例的差值
    y_pred = y_pred[:, None] - y_pred[None, :]  # 这里是算出所有位置 两两之间余弦的差值
    # 矩阵中的第i行j列  表示的是第i个余弦值-第j个余弦值
    y_true = y_true[:, None] < y_true[None, :]   # 取出负例-正例的差值
    y_true = y_true.float()
    y_pred = y_pred - (1 - y_true) * 1e12
    y_pred = y_pred.view(-1)
    if torch.cuda.is_available():
        y_pred = torch.cat((torch.tensor([0]).float().cuda(), y_pred), dim=0)  # 这里加0是因为e^0 = 1相当于在log中加了1
    else:
        y_pred = torch.cat((torch.tensor([0]).float(), y_pred), dim=0)  # 这里加0是因为e^0 = 1相当于在log中加了1
        
    return torch.logsumexp(y_pred, dim=0)

传入的参数为y_truey_pred,其中y_true是标签组成的tensor,这里在前面load_data的过程中,label被load两次。所以将1个bs传入计算loss过程中,y_true是(bs*2, )。

1.取出真实的标签

y_true = y_true[::2]

为了不重复取出,就是将label跳2取,所以取出的是真实地标签。tensor([1,0,1,0…]) 维度为(bs,)

2.对输出的句子向量进行l2归一化 后面只需要对应为相乘 就可以得到cos值了
norms = (y_pred ** 2).sum(axis=1, keepdims=True) ** 0.5
    # y_pred = y_pred / torch.clip(norms, 1e-8, torch.inf)
    y_pred = y_pred / norms

y_pred传的是经过Encoder后的向量。这里在前面的load_data过程是直接将问题对extend进列表的,因此列表为1维。所以在这里y-pred的维度为(bs*2, emb_dim)。经过归一化后,维度不变。

3.奇偶向量相乘
y_pred = torch.sum(y_pred[::2] * y_pred[1::2], dim=1) * 20

因为y_pred是一维的,而句子对是按前后顺序存进去的。所以y_pred里面是[句子A1,句子A2,句子B1,句子B2].字母代表同一对,数字代表前后关系。所以分别跳2取,就是将两个句子分开了,然后计算cos。这里结束y_pred的维度变成(bs,)。每个值代表这两句话的cos值。

4.取出负例-正例的差值
y_pred = y_pred[:, None] - y_pred[None, :] 

重点开始。因为要使相似对的cos值大于不相似对的cos。所以这里先对bs内的每两个句子对的cos计算差值(先不考虑负减正还是正减负)。得到的y_pred维度为(bs,bs)。(i,j)代表第i个样本对的cos减去第j个样本对的cos。

y_true = y_true[:, None] < y_true[None, :]   # 取出负例-正例的差值
    y_true = y_true.float()
    y_pred = y_pred - (1 - y_true) * 1e12
    y_pred = y_pred.view(-1)

上面没考虑这个cos值是属于相似的(正)还是不相似的(负)。所以利用y_true来识别对应位置的cos是正还是负。同样y_true变成了(bs,bs)。(i,j)代表第i个样本对是否小于第j个样本对,小于就是负减正(所以为True)否则就是其他类型的,后面不考虑(所以False)。
执行完第一行后的y_true,里面为1的位置就代表了该位置是负例-正例。也就是loss公式里需要优化的。而为0的是其他的,不考虑。
因此乘以1e12是进一步放大y_true中为0的部分。
y_pred变成,只保留负例减正例的值,而其他值为负无穷大,到做幂运算时,就变为0了。最后y_pred展开为1维。(bs*bs, )有些值为0,不需要优化的,剩下的代表着bs内负例-正例的cos值,需要尽可能地变小。

y_pred = torch.cat((torch.tensor([0]).float().cuda(), y_pred), dim=0)  # 这里加0是因为e^0 = 1相当于在log中加了1
torch.logsumexp(y_pred, dim=0)

总而言之,这里的y_pred一开始为(bs*2,768),在view之前需要得到一个(bs,bs)的矩阵,里面有值的地方(i,j)对应的是负对-正对的cos,其他位置为负无穷。view(-1)之后就变为了(bs*bs,),还要加个1,所以利用e的零次幂为1来cat在(bs*bs,)后变成(bs*bs+1,)。然后执行logsumexp,将y_pred的每个值x变为e的x次幂,然后在维度0上执行相加操作并返回log,作为最终loss值。

  • 2
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值