信息熵和交叉熵损失函数

信息熵

信息熵描述的是随机变量不确定度的度量

  • 熵越小,数据的不确定性越低
  • 熵越大,数据的不确定性越高

k k k表示的类别, p i p_{i} pi是某一个类比的占比,计算公式如下

H = − ∑ i = 1 k p i log ⁡ ( p i ) H = - \sum_{i=1}^{k}p_{i}\log(p_{i}) H=i=1kpilog(pi)
x 1 x_1 x1信息熵大,三个事件发生的概率相同,可以理解数据的不确定性很高。
x 2 x_2 x2信息熵较小,三个事件中倾向发生事件三,数据的确定性较高(等价数据的不确定性较低)。
x 3 x_3 x3信息熵最小,三个事件中100%发生事件三,数据的确定性最高,理论上信息熵为0。

import numpy as np
import matplotlib.pyplot as plt


def entropy(p):
    ret = 0
    for i in range(len(p)):
        ret += p[i] * np.log(p[i])
    return -ret


x1 = np.array([0.333, 0.333, 0.333])
x1_entropy = entropy(x1)
print(x1_entropy)
# 1.0985131762126916

x2 = np.array([0.1, 0.2, 0.7])
x2_entropy = entropy(x2)
print(x2_entropy)
# 0.8018185525433373

x3 = np.array([0.001, 0.001, 0.998])
x3_entropy = entropy(x3)
print(x3_entropy)
# 0.015813509223296007

交叉熵损失函数

信息熵明白后再来理解交叉熵损失函数将会简单很多,甚至是一种很自然的想法。

在多分类问题中,假设我们有 K K K个类别,并且每个样本的标签是一个one-hot向量表示(即真实标签 y y y 是一个形如 ( 1 , 0 , . . . , 0 ) (1, 0, ..., 0) (1,0,...,0) 的向量,其中只有一个位置为1,对应样本的真实类别)。预测的概率分布由模型给出,记作 y ^ \hat y y^,它也是一个形如 ( p 1 , p 2 , . . . , p K ) (p_1, p_2, ..., p_K) (p1,p2,...,pK) 的向量,其中 p k p_k pk 表示样本属于第 k k k类的概率,并且满足
0 ≤ p k ≤ 1 及 ∑ k = 1 K p k = 1 0 \leq p_k \leq 1 及 \sum_{k=1}^{K} p_k = 1 0pk1k=1Kpk=1
交叉熵损失函数(Categorical Cross-Entropy Loss)可以定义为
l o s s = − 1 N ∑ i = 1 N ∑ k = 1 K y i , k log ⁡ ( p i , k ) loss = -\frac{1}{N} \sum_{i=1}^{N} \sum_{k=1}^{K} y_{i,k} \log(p_{i,k}) loss=N1i=1Nk=1Kyi,klog(pi,k)
其中
N 是样本数,  K 分类数 y i , k 是第 i 个样本类别 k 的真实标签,one-hot向量表示 p i , k 是第 i 个样本类别 k 的预测概率 \text{$N$是样本数, $K$分类数} \\ \text{$y_{i,k}$是第$i$个样本类别$k$的真实标签,one-hot向量表示} \\ \text{$p_{i,k}$是第$i$个样本类别$k$的预测概率} \\ N是样本数, K分类数yi,k是第i个样本类别k的真实标签,one-hot向量表示pi,k是第i个样本类别k的预测概率

我们可以将真实标签 y y y理解那个为1的类别其概率为100%,其他类别的概率都为0, y ^ \hat y y^ 是模型预测的各个类别的概率分布,两个向量的点乘求和(就是公式中里面那个求和计算)就是这个样本的一次交叉熵损失值了,这和信息熵几乎一样。

为了避免计算过程中出现 l o g ( 0 ) log(0) log(0) 导致的未定义或无穷大问题,在实际计算时通常会对预测概率进行平滑处理,设置一个很小的正值 ϵ \epsilon ϵ
p i , k = { ϵ p i , k = 0 1 − ϵ p i , k = 1 p_{i,k} = \begin{cases} \epsilon & p_{i,k} = 0 \\ 1-\epsilon & p_{i,k} = 1 \end{cases} pi,k={ϵ1ϵpi,k=0pi,k=1
这里的 ϵ \epsilon ϵ被称为“平滑项”,它保证了即使预测概率非常接近0或1,也能得到一个有限的交叉熵值。

代码实现

可以看出torch实现的交叉熵损失函数F.cross_entropy 和我们实现的结果是一致的(注意两者的实现公式稍有不同)
代码中cross_entropy函数对信息熵的计算和本文推导的公式稍有不同,本质一样,主要为了代码编写方便。

import numpy as np
import torch
import torch.nn.functional as F


def cross_entropy(x, y, reduction="mean"):
    """ :param x: torch.tensor 是模型预测原始输出, 未经过softmax处理
        :param y: torch,tensor 是真实标签的类别索引"""
    x_softmax = torch.softmax(x, dim=1)
    x_softmax = x_softmax.numpy()
    y = y.numpy()
    if reduction == "mean":
        ret = 0
        for i in range(x.shape[0]):
            # y[i]真实类别的索引
            # x_softmax[i] 第i个样本的所有预测概率
            # x_softmax[i][y[i]] 真实类别对应的预测概率值
            ret = ret + np.log(x_softmax[i][y[i]])
        return -ret / x.shape[1]
    elif reduction == "none":
        ret = []
        for i in range(x.shape[0]):
            ret.append(-np.log(x_softmax[i][y[i]]))
        return np.array(ret)


# 模型预测结果(构造了一个预测全部正确的数据,损失值最小)
predict = torch.tensor([[2, 4, 6],
                        [8, 4, 3],
                        [1, 5, 2]], dtype=torch.double)
# 两者等价,参考torch的接口设计,这里实现了第一种
# 称为 class indices方式
true_label = torch.tensor([2, 0, 1])  
# 称为 class probabilities方式
# true_label = torch.tensor([[0, 0, 1], [1, 0, 0], [0, 1, 0]], dtype=torch.double)  

reduction = "mean"
# reduction = "none"
loss = cross_entropy(predict, true_label, reduction=reduction)
print(loss)
# 0.0778534741320504

loss2 = F.cross_entropy(predict, true_label, reduction=reduction)
print(loss2)
# tensor(0.0779, dtype=torch.float64)

参考

torch.nn.CrossEntropyLoss 源码接口
torch.nn.functional.cross_entropy 源码接口
https://blog.csdn.net/weixin_42924890/article/details/135732692

  • 18
    点赞
  • 15
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值