torch.nn.CrossEntropyLoss() comes up constantly in classification tasks. I had only ever used it as a black box without understanding how it works internally, so let's work through its internals here.
softmax
Let's start with the softmax function.
In the softmax mechanism, to obtain the output of the output layer (i.e., the final output) we apply not sigmoid() but the so-called softmax(). It takes a vector as input and outputs a vector whose entries lie between 0 and 1 and sum to 1; in classification tasks these entries are the probabilities that appear in the cross-entropy loss.
softmax(x_i) = \frac{e^{x_i}}{\sum_{j} e^{x_j}}
softmax is differentiable:
\frac{\partial\, softmax(x_i)}{\partial x_i} = \frac{e^{x_i} \cdot \sum_{j} e^{x_j} - e^{x_i} \cdot e^{x_i}}{\left(\sum_{j} e^{x_j}\right)^2} = \frac{e^{x_i}}{\sum_{j} e^{x_j}} \cdot \frac{\sum_{j} e^{x_j} - e^{x_i}}{\sum_{j} e^{x_j}} = softmax(x_i) \cdot (1 - softmax(x_i))
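As a quick sanity check, we can compare this closed-form derivative against the gradient computed by torch.autograd. A minimal sketch (note it only checks the diagonal entry ∂softmax(x_i)/∂x_i of the full Jacobian):
import torch
from torch.nn.functional import softmax

x = torch.tensor([0.5, 0.9, 3.0], requires_grad=True)
s = softmax(x, dim=0)

# Backprop through the single output s[i]; x.grad then holds the
# row d s[i] / d x of the Jacobian, whose i-th entry is the diagonal term
i = 1
s[i].backward()
print(x.grad[i].item())            # autograd's d softmax(x_i) / d x_i
print((s[i] * (1 - s[i])).item())  # closed form; both print the same value (≈ 0.0913)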
Code implementation:
import numpy as np
import torch
from torch.nn.functional import softmax

arr = np.array([0.5, 0.9, 3, 6, -8])
# manual softmax in numpy: exponentiate, then normalize by the sum
np_softmax = np.exp(arr) / np.sum(np.exp(arr))
print(np_softmax, np_softmax.sum())

arr_tensor = torch.Tensor(arr)
torch_softmax = softmax(arr_tensor, dim=0)
print(torch_softmax, torch.sum(torch_softmax))
[3.85554872e-03 5.75180280e-03 4.69701989e-02 9.43421665e-01 7.84482209e-07] 1.0
tensor([3.8555e-03, 5.7518e-03, 4.6970e-02, 9.4342e-01, 7.8448e-07]) tensor(1.)
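One practical caveat: computing np.exp() directly overflows once the inputs are large. The standard remedy is to subtract the maximum before exponentiating; since softmax is shift-invariant, the result is unchanged (PyTorch's softmax handles this internally). A minimal sketch:
import numpy as np

arr = np.array([1000.0, 1001.0, 1002.0])
# naive version overflows: np.exp(1000.0) is inf, yielding nan
# naive = np.exp(arr) / np.sum(np.exp(arr))

# shift-invariant trick: softmax(x) == softmax(x - max(x))
shifted = arr - arr.max()
stable_softmax = np.exp(shifted) / np.sum(np.exp(shifted))
print(stable_softmax)  # [0.09003057 0.24472847 0.66524096]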
logsoftmax()
logsoftmax() simply applies a log to the result of softmax(); the expression is:
logsoftmax(x_i) = \log \frac{e^{x_i}}{\sum_{j} e^{x_j}}
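Expanding the log gives logsoftmax(x_i) = x_i - \log \sum_{j} e^{x_j}, i.e., the whole computation can be done in log space. This is why F.log_softmax is preferable to calling torch.log on a softmax result, which returns -inf once the softmax underflows to zero. A small sketch:
import torch
import torch.nn.functional as F

x = torch.tensor([-100.0, 0.0, 100.0])

# naive two-step version: softmax underflows to exactly 0 for the
# smallest entry, so the subsequent log() returns -inf
naive = torch.log(F.softmax(x, dim=0))
print(naive)   # first entry is -inf

# fused version works in log space (x_i - logsumexp(x)) and stays finite
fused = F.log_softmax(x, dim=0)
print(fused)   # tensor([-200., -100., 0.])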
NLLLoss()
The negative log-likelihood loss (Negative Log-Likelihood):
nllloss = -\frac{1}{N} \sum_{i=1}^{N} y_i \cdot logsoftmax(x_i)
where y_i is the one-hot encoding of the target label (in practice, the packaged functions take class indices directly, so you don't need to pass one-hot encodings yourself).
Code implementation:
import torch
import torch.nn.functional as F

# input = torch.randn((2, 3))
input = torch.tensor([[ 1.1041,  0.7217,  1.1316],
                      [ 0.1365, -0.5008, -1.7217]])
target = torch.tensor([0, 2])
# num_classes is set so the one-hot tensor matches input's shape
one_hot = F.one_hot(target, num_classes=3).float()
# note: the result of torch.sum must be reshaped so it broadcasts per row
softmax_mine = torch.exp(input) / torch.sum(torch.exp(input), dim=1).reshape((-1, 1))
logsoftmax_mine = torch.log(softmax_mine)
nllloss = -torch.sum(one_hot * logsoftmax_mine) / target.shape[0]
print(nllloss)
# ===================== #
# verify with torch.nn.functional
logsoftmax_f = F.log_softmax(input, dim=1)
nllloss_f = F.nll_loss(logsoftmax_f, target)
print(nllloss_f)
# ===================== #
# verify directly with cross_entropy
cross_entropy = F.cross_entropy(input, target)
print(cross_entropy)
out >>> tensor(1.6884)
tensor(1.6884)
tensor(1.6884)
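All three results match: cross_entropy is exactly log_softmax followed by nll_loss, which is why torch.nn.CrossEntropyLoss() expects raw, unnormalized logits rather than probabilities. For completeness, a sketch of the module form, which gives the same number:
import torch
import torch.nn as nn

input = torch.tensor([[ 1.1041,  0.7217,  1.1316],
                      [ 0.1365, -0.5008, -1.7217]])
target = torch.tensor([0, 2])

# the module wraps F.cross_entropy; reduction='mean' (the default)
# averages over the batch, matching the 1/N factor in the formula above
criterion = nn.CrossEntropyLoss(reduction='mean')
print(criterion(input, target))   # tensor(1.6884)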