KL散度损失函数

2021SC@SDUSC

之前学习了信息熵损失函数,之后来学习KI散度损失函数

在我们使用的模型中,这个模型的输入样本和样本标签已定,它们所对应的真实分布概率也确定

KL散度KL divergence

全称:Kullback-Leibler Divergence。

用途:比较两个概率分布的接近程度。
在统计应用中,我们经常需要用一个简单的,近似的概率分布 * 来描述。

观察数据 D 或者另一个复杂的概率分布 。这个时候,我们需要一个量来衡量我们选择的近似分布 * 相比原分布 f 究竟损失了多少信息量,这就是KL散度起作用的地方。

熵(entropy)

想要考察信息量的损失,就要先确定一个描述信息量的量纲。

在信息论这门学科中,一个很重要的目标就是量化描述数据中含有多少信息。

为此,提出了的概念,记作 H 。

一个概率分布所对应的表达如下:

 

KL散度的计算

现在,我们能够量化数据中的信息量了,就可以来衡量近似分布带来的信息损失了。
KL散度的计算公式其实是熵计算公式的简单变形,在原有概率分布 p 上,加入我们的近似概率分布 ,计算他们的每个取值对应对数的差:

 

换句话说,KL散度计算的就是数据的原分布与近似分布的概率的对数差的期望值。

在对数以2为底时, log 2 ,可以理解为“我们损失了多少位的信息”。

写成期望形式:

 更常见的是以下形式:

 

现在,我们就可以使用KL散度衡量我们选择的近似分布与数据原分布有多大差异了。

散度不是距离

 因为KL散度不具有交换性,所以不能理解为“距离”的概念,衡量的并不是两个分布在空间中的远近,更准确的理解还是衡量一个分布相比另一个分布的信息损失(infomation lost)。

使用KL散度进行优化

通过不断改变预估分布的参数,我们可以得到不同的KL散度的值。

在某个变化范围内,KL散度取到最小值的时候,对应的参数是我们想要的最优参数。

这就是使用KL散度优化的过程。

KL散度=交叉熵-信息熵

 

        # unsup loss
        if unsup_batch:
            # ori
            with torch.no_grad():
                ori_logits = model(ori_input_ids, ori_segment_ids, ori_input_mask)
                ori_prob   = F.softmax(ori_logits, dim=-1)    # KLdiv target
                # ori_log_prob = F.log_softmax(ori_logits, dim=-1)

                # confidence-based masking
                if cfg.uda_confidence_thresh != -1:
                    unsup_loss_mask = torch.max(ori_prob, dim=-1)[0] > cfg.uda_confidence_thresh
                    unsup_loss_mask = unsup_loss_mask.type(torch.float32)
                else:
                    unsup_loss_mask = torch.ones(len(logits) - sup_size, dtype=torch.float32)
                unsup_loss_mask = unsup_loss_mask.to(_get_device())
                    
            # aug
            # softmax temperature controlling
            uda_softmax_temp = cfg.uda_softmax_temp if cfg.uda_softmax_temp > 0 else 1.
            aug_log_prob = F.log_softmax(logits[sup_size:] / uda_softmax_temp, dim=-1)

            # KLdiv loss
            """
                nn.KLDivLoss (kl_div)
                input : log_prob (log_softmax)
                target : prob    (softmax)
                https://pytorch.org/docs/stable/nn.html

                unsup_loss is divied by number of unsup_loss_mask
                it is different from the google UDA official
                The official unsup_loss is divided by total
                https://github.com/google-research/uda/blob/master/text/uda.py#L175
            """
            unsup_loss = torch.sum(unsup_criterion(aug_log_prob, ori_prob), dim=-1)
            unsup_loss = torch.sum(unsup_loss * unsup_loss_mask, dim=-1) / torch.max(torch.sum(unsup_loss_mask, dim=-1), torch_device_one())
            final_loss = sup_loss + cfg.uda_coeff*unsup_loss

            return final_loss, sup_loss, unsup_loss
        return sup_loss, None, None

  • 16
    点赞
  • 44
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值