SotMax函数的对数版本Log-SoftMax

Softmax函数是一种将任意实数的向量归一化成概率分布的函数,其中每个元素的值代表该类的概率,并且所有元素的和为1。对数版本的Softmax(通常称为Log-Softmax)在多分类的深度学习模型中经常使用,尤其是在计算损失函数时。

Softmax函数

Softmax函数定义为:
Softmax ( z i ) = e z i ∑ j e z j \text{Softmax}(z_i) = \frac{e^{z_i}}{\sum_{j} e^{z_j}} Softmax(zi)=jezjezi
其中, z z z 是输入向量, z i z_i zi 是向量中的第 i i i个元素,分母是对整个输入向量所有元素应用指数函数后的和。

Log-Softmax函数

Softmax 函数的对数版本,它用于正规化 logits 矩阵,使得每行的元素和为 1(在指数和对数的作用下)。这样处理后的结果可以直接用于计算对数似然,这是监督对比损失中的一个关键部分。
对Softmax的结果取对数得到Log-Softmax,其表达式为:
Log-Softmax ( z i ) = log ⁡ ( e z i ∑ j e z j ) \text{Log-Softmax}(z_i) = \log\left(\frac{e^{z_i}}{\sum_{j} e^{z_j}}\right) Log-Softmax(zi)=log(jezjezi)
利用对数的性质,这可以简化为:
Log-Softmax ( z i ) = z i − log ⁡ ( ∑ j e z j ) \text{Log-Softmax}(z_i) = z_i - \log\left(\sum_{j} e^{z_j}\right) Log-Softmax(zi)=zilog(jezj)
这里, z i z_i zi 是输入向量的第 i i i 个元素,而 ∑ j e z j \sum_{j} e^{z_j} jezj 是输入向量的所有元素的指数和的对数。这种形式的好处是可以直接从logits(即模型输出的未经归一化的预测值)计算出来,而不需要先计算Softmax概率再取对数。

数值稳定性

在实际应用中,直接计算 e z i e^{z_i} ezi 可能导致数值溢出,尤其是当 z i z_i zi 的值很大时。为了避免这种情况,常常从每个 z i z_i zi 中减去 z z z 的最大值:
Log-Softmax ( z i ) = z i − max ⁡ ( z ) − log ⁡ ( ∑ j e z j − max ⁡ ( z ) ) \text{Log-Softmax}(z_i) = z_i - \max(z) - \log\left(\sum_{j} e^{z_j - \max(z)}\right) Log-Softmax(zi)=zimax(z)log(jezjmax(z))
这种处理方式不影响结果,但可以极大地提高数值稳定性。

应用于损失函数

在使用监督对比损失(如交叉熵损失)时,Log-Softmax非常有用。交叉熵损失可以定义为:
Cross-Entropy ( y , y ^ ) = − ∑ i y i log ⁡ ( y ^ i ) \text{Cross-Entropy}(y, \hat{y}) = -\sum_{i} y_i \log(\hat{y}_i) Cross-Entropy(y,y^)=iyilog(y^i)
其中 y y y 是真实标签的one-hot编码, y ^ \hat{y} y^ 是预测概率。如果使用Log-Softmax,损失可以直接使用模型输出的logits来计算,形式为:
Cross-Entropy ( y , z ) = − ∑ i y i ( Log-Softmax ( z ) i ) \text{Cross-Entropy}(y, z) = -\sum_{i} y_i (\text{Log-Softmax}(z)_i) Cross-Entropy(y,z)=iyi(Log-Softmax(z)i)
这种方法避免了显式计算概率分布,直接利用logits进行计算,从而提高了效率并保持了数值的稳定性。

总结

Log-Softmax在处理大规模多类分类问题时非常有用,尤其是在模型需要处理大量输出类别时。它使得从logits到损失计算的过程更直接、更高效,并且通过避免直接的概率计算,降低了因数值问题导致的误差。

应用

例如,监督对比损失,代码如下:

import torch
import torch.nn as nn

class SupConLoss(nn.Module):
    def __init__(self, device):
        super(SupConLoss, self).__init__()
        self.device = device
        self.temperature = 1.0
    def forward(self, text_features, image_features, t_label, i_targets): 
        # 文本特征数量
        batch_size = text_features.shape[0] 
        # 图像特征数量
        batch_size_N = image_features.shape[0] 
        # 找出属于同一个ID标签的文本和图像特征,可作为文本和图像特征的相似度标签
        mask = torch.eq(t_label.unsqueeze(1).expand(batch_size, batch_size_N), \
            i_targets.unsqueeze(0).expand(batch_size,batch_size_N)).float().to(self.device) 

        # 计算文本和图像特征的相似度
        logits = torch.div(torch.matmul(text_features, image_features.T),self.temperature)
        # for numerical stability 为了提高计算稳定性
        logits_max, _ = torch.max(logits, dim=1, keepdim=True)
        logits = logits - logits_max.detach() 
        # 计算logits的指数
        exp_logits = torch.exp(logits) 
        # 计算log-softmax, 用于正规化 logits 矩阵,使得每行的元素和为 1
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True)) 
        # 首先使用前面创建的 mask 掩码矩阵来选择出相同标签的特征对
        # 之后,对每行进行求和,得到每个文本特征对应的所有正样本的 log 概率和
        # 最后,将这些和除以每个文本特征对应的正样本数量,得到每个文本特征对其所有正样本的平均对数概率。
        mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1) 
        # 将所有文本特征的平均对数概率取负值后求平均
        loss = - mean_log_prob_pos.mean()

在计算损失函数的过程中,从“计算每对特征的指数值”开始的步骤是关键环节,它们共同构成了对比损失函数的核心。下面是这些步骤的详细介绍:

计算每对特征的指数值

exp_logits = torch.exp(logits)

在这一步中,对之前计算并校正的 logits 矩阵的每个元素应用指数函数。logits 矩阵的每个元素表示一对特征(一个来自文本特征,一个来自图像特征)之间的相似度,经过温度调节后的点积。应用指数函数是为了将这些相似度转换为非负权重,这些权重反映了特征对之间的相对接近程度。在概率论中,这种转换常见于 Softmax 函数的计算中,它有助于后续的归一化处理。

计算log概率

log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))

接下来,需要计算每个 logits 对应的 log 概率。首先,使用 torch.logexp_logits 矩阵的行求和,得到的是每个文本特征与所有图像特征的指数权重和的对数。然后,用 logits 矩阵中的每个元素减去这个对数和,得到的 log_prob 表示在给定文本特征的情况下,每个图像特征作为正样本的对数概率。

这一步骤是 Softmax 函数的对数版本,它用于正规化 logits 矩阵,使得每行的元素和为 1(在指数和对数的作用下)。这样处理后的结果可以直接用于计算对数似然,这是监督对比损失中的一个关键部分。

计算正样本对的平均对数概率

mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)

在这一步中,首先使用前面创建的 mask 掩码矩阵来选择出相同标签的特征对,mask 中的元素是 1(如果标签匹配)或 0(如果标签不匹配)。然后,将掩码与 log_prob 对应元素相乘,这样只有相同类别的特征对的 log 概率被保留,其余的变为 0。

之后,对每行进行求和,得到每个文本特征对应的所有正样本的 log 概率和。最后,将这些和除以每行掩码值的和(即每个文本特征对应的正样本数量),得到的 mean_log_prob_pos 是每个文本特征对其所有正样本的平均对数概率。

计算最终损失值

loss = -mean_log_prob_pos.mean()

最后一步是取 mean_log_prob_pos 的负平均值作为最终的损失值。这一步将所有文本特征的平均对数概率取负值后求平均,目的是使得模型在训练过程中尽可能地将相同标签的文本特征和图像特征映射得更接近,而使不同标签的特征映射得更远。这种损失函数促使模型学习到能够区分不同类别的有效特征表示。

这些步骤共同构成了监督对比学习的损失计算过程,使得学习到的特征不仅在单个模态内部具有区分性,而且能在不同模态间实现良好的对齐和匹配。

  • 27
    点赞
  • 24
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

yiruzhao

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值