contrastive loss function (papers)

### 温度参数对监督对比损失函数的影响 在监督对比学习框架下,温度参数 \( \tau \) 是控制特征空间分布的关键超参数之一。它决定了正样本对之间的相似性和负样本对之间差异性的相对重要程度[^1]。 当温度参数较低时 (\( \tau \to 0 \)),模型倾向于更严格地区分正负样本对,这可能导致优化过程更加关注局部细节而忽略全局结构信息[^3]。这种情况下,虽然可能提高分类边界的清晰度,但也容易引发过拟合现象,尤其是在训练数据有限的情况下[^2]。 相反,在较高温度设置下 (\( \tau \to 1 \) 或更大),模型会放松对于正负样本区分的要求,从而允许更多噪声存在并促进泛化能力提升。然而,如果温度过高,则可能会削弱有效信号的作用,使得最终学到的表征质量下降[^4]。 因此,在实际应用过程中需要通过实验来寻找最佳平衡点以获得最优性能表现 。通常可以通过网格搜索或者随机搜索方法来进行调优操作 ,同时结合验证集上的指标评估结果做出决策 。 ```python import torch from torch.nn import functional as F def supervised_contrastive_loss(features, labels, temperature=0.1): """ Compute Supervised Contrastive Loss. Args: features: Tensor of shape (batch_size, feature_dim). labels: Tensor of shape (batch_size,) containing class indices. temperature: Float representing the temperature parameter. Returns: Scalar tensor containing the loss value. """ batch_size = features.shape[0] mask = torch.eq(labels.unsqueeze(1), labels.unsqueeze(0)).float() logits = torch.div(torch.matmul(features, features.T), temperature) exp_logits = torch.exp(logits - torch.max(logits, dim=1, keepdim=True)[0]) log_prob = logits - torch.log(exp_logits.sum(dim=1, keepdim=True)) mean_log_prob_pos = (mask * log_prob).sum(dim=1) / mask.sum(dim=1) loss = -(temperature / batch_size) * mean_log_prob_pos.sum() return loss ```
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值