本文只考虑基本情况,未考虑加权。
torch.nnCrossEntropyLosss使用的公式
l
o
s
s
(
x
,
c
l
a
s
s
)
=
−
l
o
g
(
e
x
p
(
x
[
c
l
a
s
s
]
∑
j
e
x
p
(
x
[
j
]
)
)
loss(x,class)=-log(\frac {exp(x[class]} {\sum_jexp(x[j])})
loss(x,class)=−log(∑jexp(x[j])exp(x[class])
=
−
x
[
c
l
a
s
s
]
+
l
o
g
(
∑
j
e
x
p
(
x
[
j
]
)
)
(1)
=-x[class]+log(\sum_jexp(x[j])) \tag {1}
=−x[class]+log(j∑exp(x[j]))(1)目标类别采用one-hot编码
其中,class表示当前样本类别在one-hot编码中对应的索引(从0开始),
x[j]表示预测函数的第j个输出
公式(1)表示先对预测函数使用softmax计算每个类别的概率,再使用log(以e为底)计算后的相反数表示当前类别的损失,只表示其中一个样本的损失计算方式,非全部样本。
每个样本使用one-hot编码表示所属类别时,只有一项为1,因此与基本的交叉熵损失函数相比,省略了其它值为0的项,只剩(1)所表示的项。
【sample】
已知条件:共3种类别,输入两个样本,第一个样本为类别class=0,第二个样本为类别class=2
预测函数输出:
[
[
0.0541
,
0.1762
,
0.9489
]
,
[
−
0.0288
,
−
0.8072
,
0.4909
]
]
[[\;\;\;0.0541, \;\;\;0.1762, 0.9489], \\ \quad \quad \quad\quad\quad\quad[-0.0288, -0.8072, 0.4909]]
[[0.0541,0.1762,0.9489],[−0.0288,−0.8072,0.4909]],shape为2行3列
基于此,计算损失:
首先softmax计算两个样本对应类别的概率:
e
0.0541
e
0.0541
+
e
0.1762
+
e
0.9489
=
0.2185
\frac {e^{0.0541}} {e^{0.0541} + e^{0.1762} + e^{0.9489}} =0.2185
e0.0541+e0.1762+e0.9489e0.0541=0.2185
e
0.4909
e
−
0.0288
+
e
−
0.8072
+
e
0.4909
=
0.5354
\frac {e^{0.4909}} {e^{-0.0288} + e^{-0.8072} + e^{0.4909}} =0.5354
e−0.0288+e−0.8072+e0.4909e0.4909=0.5354
然后计算log之后的相反数:
−
l
o
g
(
0.2185
)
=
1.5210
-log(0.2185) = 1.5210
−log(0.2185)=1.5210
−
l
o
g
(
0.5354
)
=
0.6247
-log(0.5354) = 0.6247
−log(0.5354)=0.6247
取均值:
1.5210
+
0.6247
2
=
1.073
\frac {1.5210+0.6247}{2}=1.073
21.5210+0.6247=1.073
【torch.nn.CrossEntropyLoss使用流程】
torch.nn.CrossEntropyLoss为一个类,并非单独一个函数,使用到的相关简单参数会在使用中说明,并非对所有参数进行说明。
首先创建类对象
In [1]: import torch
In [2]: import torch.nn as nn
In [3]: loss_function = nn.CrossEntropyLoss(reduction="none")
参数reduction默认为"mean",表示对所有样本的loss取均值,最终返回只有一个值
参数reduction取"none",表示保留每一个样本的loss
计算损失
In [4]: pred = torch.tensor([[0.0541,0.1762,0.9489],[-0.0288,-0.8072,0.4909]], dtype=torch.float32)
In [5]: class_index = torch.tensor([0, 2], dtype=torch.int64)
In [6]: loss_value = loss_function(pred, class_index)
In [7]: loss_value
Out[7]: tensor([1.5210, 0.6247]) # 与上述【sample】计算一致
实际计算损失值调用函数时,传入pred预测值与class_index类别索引
在传入每个类别时,class_index应为一维,长度为样本个数,每个元素表示对应样本的类别索引,非one-hot编码方式传入
【测试torch.nn.CrossEntropyLoss的reduction参数为默认值"mean"】
In [1]: import torch
In [2]: import torch.nn as nn
In [3]: loss_function = nn.CrossEntropyLoss(reduction="mean")
In [4]: pred = torch.tensor([[0.0541,0.1762,0.9489],[-0.0288,-0.8072,0.4909]], dtype=torch.float32)
In [5]: class_index = torch.tensor([0, 2], dtype=torch.int64)
In [6]: loss_value = loss_function(pred, class_index)
In [7]: loss_value
Out[7]: 1.073 # 与上述【sample】计算一致