Hand-Rolled Algorithm Notes -- Cross-Entropy Loss and the Contrastive-Learning InfoNCE Loss from Scratch

This post covers how binary and multi-class cross-entropy are defined and how to implement both in Python with numpy: Cal_BCE handles the binary case, Cal_MCE the multi-class case, and a one-hot encoding helper ties them together. It then walks through two PyTorch implementations of the contrastive-learning InfoNCE loss.

1--Cross-Entropy Definition

1-1--Binary Cross-Entropy

Loss(y_{true}, y_{pred}) = -[y_{true}\cdot \log(y_{pred}) + (1-y_{true})\cdot \log(1-y_{pred})]

where y_{pred} is the probability predicted by the model and y_{true} is the ground-truth class label (0 or 1).
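
As a quick sanity check on the formula: for a positive sample (y_{true} = 1) predicted with probability 0.7, only the first term survives:

Loss(1, 0.7) = -\log(0.7) \approx 0.357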

1-2--Multi-Class Cross-Entropy

Loss = -\frac{1}{N}\sum_{i=1}^{N}\sum_{j=1}^{M}y_{true_{ij}}\cdot \log(y_{pred_{ij}})

where N is the batch size and M is the number of classes.
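
Because each row of y_{true} is one-hot (see Cal_oneHot below), the inner sum keeps only the log-probability at the true class j^{*} of sample i:

\sum_{j=1}^{M}y_{true_{ij}}\cdot \log(y_{pred_{ij}}) = \log(y_{pred_{ij^{*}}})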

2--Code

'''
@File    :   Cal_CE_Loss.py
@Time    :   2024/03/27 20:13:00
@Author  :   Jinfu Liu
@Version :   1.0
@Desc    :   Calculate Cross Entropy Loss
'''

import numpy as np

def Cal_BCE(y_true, y_pred):
    """计算二元交叉熵
    Args:
        y_true: 真实标签 [B]
        y_pred: 预测概率 [B]
    Return:
        二元交叉熵损失
    """
    ce_loss = y_true*(np.log(y_pred)) + (1-y_true)*(np.log(1-y_pred)) # B
    total_ce = np.sum(ce_loss) # 1
    bce_loss = -total_ce / y_true.shape[0] # 1
    return bce_loss

def Cal_MCE(y_true, y_pred):
    """计算多元交叉熵
    Args:
        y_true: 真实标签 [B, C]
        y_pred: 预测概率 [B, C]
    Return:
        多元交叉熵损失
    """
    ce_loss = y_true*np.log(y_pred) # [B, C]
    total_ce = np.sum(ce_loss) # 1
    mce_loss = -total_ce / y_true.shape[0] # 1
    return mce_loss

def Cal_oneHot(y_true:np.ndarray, num_label:int):
    """计算独热向量
    Args:
        y_true: 真实标签
        num_label: 类别数
    Return:
        独热向量
    """
    one_hot = np.zeros((y_true.shape[0], num_label)) # [B, C]
    for idx, val in enumerate(y_true):
        one_hot[idx][val] = 1
    return one_hot
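
For reference, the loop above can be replaced by a standard numpy indexing idiom that produces the same matrix (assuming y_true holds integer labels in [0, num_label)):

one_hot = np.eye(num_label)[y_true]  # pick rows of the identity matrix by label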

if __name__ == "__main__":
    y_true = np.array([0, 1, 1, 0, 1, 0])
    y_pred = np.array([0.3, 0.7, 0.8, 0.5, 0.6, 0.4])
    bce_loss = Cal_BCE(y_true = y_true, y_pred = y_pred)
    print("BCE_Loss: ", bce_loss)
    
    y_true = np.array([0, 1, 2, 3, 1])
    y_true_oneHot = Cal_oneHot(y_true, num_label = 4)
    y_pred = np.array([[0.7, 0.1, 0.1, 0.1],
                       [0.1, 0.7, 0.1, 0.1],
                       [0.1, 0.1, 0.7, 0.1],
                       [0.1, 0.1, 0.1, 0.7],
                       [0.1, 0.7, 0.1, 0.1]])
    mce_loss = Cal_MCE(y_true = y_true_oneHot, y_pred = y_pred)
    print("MCE_Loss: ", mce_loss)
    

3--InfoNCE loss
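
Both implementations below compute the standard InfoNCE objective: with anchor embedding a_i, its positive p_i, temperature \tau, and cosine similarity sim(\cdot, \cdot),

L_i = -\log\frac{\exp(sim(a_i, p_i)/\tau)}{\sum_{k}\exp(sim(a_i, p_k)/\tau)}

where k runs over the positive plus the in-batch negatives; the two variants differ only in whether the positive pair is also kept among the denominator terms.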

import torch
import torch.nn.functional as F

# Reference: https://zhuanlan.zhihu.com/p/506544456; the labels setup follows the explanation in that post's comment section

# Variant 1: the positive pair is not repeated in the denominator
def info_nce_loss(anchor_features, positive_features, temperature):
    batch_size, feature_dim = anchor_features.shape # B C

    # L2-normalize so dot products become cosine similarities
    anchor_normalized = F.normalize(anchor_features, dim = 1) # B C
    positive_normalized = F.normalize(positive_features, dim = 1) # B C
 
    # Similarity between each anchor and its own positive (cosine similarity)
    positive_logits = torch.einsum('nc,nc->n', [anchor_normalized, positive_normalized]).unsqueeze(1) / temperature # B 1

    # Similarity between each anchor and every positive in the batch; off-diagonal entries serve as negatives
    negative_logits = torch.einsum('nc,kc->nk', [anchor_normalized, positive_normalized]) / temperature # B B

    # Drop the diagonal from the negative logits: entry (i, i) is the positive
    # pair (anchor_i, positive_i), which must not also be counted as a negative
    new_negative_logits = torch.zeros(negative_logits.shape[0], negative_logits.shape[1] - 1,
                                      dtype=negative_logits.dtype, device=negative_logits.device) # B (B-1)
    for i in range(negative_logits.shape[0]):
        new_negative_logits[i] = torch.cat((negative_logits[i, :i], negative_logits[i, i+1:]), dim = 0)

    # Concatenate positive and negative logits; the positive logit sits in column 0,
    # which is why the labels below are all 0
    logits = torch.cat([positive_logits, new_negative_logits], dim=1) # B (1+B-1)

    # Target labels: the positive class is at index 0 for every sample
    labels = torch.zeros(batch_size, dtype=torch.long, device=anchor_features.device)

    # Cross-entropy over the logits yields the InfoNCE loss
    loss = F.cross_entropy(logits, labels)

    return loss
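
The per-row concatenation loop above is easy to read but slow for large batches; an equivalent vectorized rewrite (mine, not from the referenced post) removes the diagonal with a boolean mask:

mask = ~torch.eye(batch_size, dtype=torch.bool, device=negative_logits.device)
new_negative_logits = negative_logits[mask].view(batch_size, batch_size - 1)  # B (B-1), row order preserved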

# Variant 2: the positive pair also appears in the denominator
def info_nce_loss2(anchor_features, positive_features, temperature):
    batch_size, feature_dim = anchor_features.shape # B C

    # L2-normalize so dot products become cosine similarities
    anchor_normalized = F.normalize(anchor_features, dim = 1) # B C
    positive_normalized = F.normalize(positive_features, dim = 1) # B C
 
    # Similarity between each anchor and its own positive (cosine similarity)
    positive_logits = torch.einsum('nc,nc->n', [anchor_normalized, positive_normalized]).unsqueeze(1) / temperature # B 1

    # Similarity between each anchor and every positive in the batch; here the diagonal (the positive pair) is kept
    negative_logits = torch.einsum('nc,kc->nk', [anchor_normalized, positive_normalized]) / temperature # B B

    # Concatenate positive and negative logits; the positive logit sits in column 0,
    # which is why the labels below are all 0 (the positive pair also appears on the diagonal of negative_logits)
    logits = torch.cat([positive_logits, negative_logits], dim=1) # B (1+B)

    # Target labels: the positive class is at index 0 for every sample
    labels = torch.zeros(batch_size, dtype=torch.long, device=anchor_features.device)

    # Cross-entropy over the logits yields the InfoNCE loss
    loss = F.cross_entropy(logits, labels)

    return loss
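
Because the diagonal of negative_logits equals positive_logits, variant 2 counts the exponentiated positive similarity twice in the softmax denominator, so for identical inputs loss2 is always slightly larger than loss1; both formulations appear in practice.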

if __name__ == "__main__":
    batch_size = 32  # batch size
    feature_dim = 128  # feature dimension
    temperature = 0.1  # temperature parameter

    # Stand-ins for features that would normally come from an encoder
    anchor_features = torch.randn(batch_size, feature_dim)
    positive_features = torch.randn(batch_size, feature_dim)

    # Compute the InfoNCE loss with both variants
    loss1 = info_nce_loss(anchor_features, positive_features, temperature)

    loss2 = info_nce_loss2(anchor_features, positive_features, temperature)
    
    print("loss1: ", loss1)
    print("loss2: ", loss2)

    print("All Done!")
