Cross-Entropy Loss Function
The cross-entropy loss for multi-class classification is:
$$
\mathrm{CrossEntropy}(x) = -\sum_{i=1}^{c} y_i \cdot \log(\hat{y}_i), \quad \hat{y}_i = \mathrm{softmax}(x_i) = \frac{e^{x_i}}{\sum_{j=1}^{c} e^{x_j}}
$$
Here $y$ is the true label distribution, $\hat{y}$ is the predicted distribution obtained from the network output $x$, and $c$ is the number of classes.
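Computing Multi-Class Cross-Entropy Loss with TensorFlow
In TensorFlow, the cross-entropy loss function is tf.nn.softmax_cross_entropy_with_logits.
In PyTorch, the corresponding loss is torch.nn.CrossEntropyLoss.
Consider the following data: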
$$
\text{labels} = [0.0, 0.8, 0.2] \quad (\text{the true output } y) \\
\text{logits} = [0.0, 5.0, 1.0] \quad (\text{the raw network output } x \text{, before softmax})
$$
The logits are raw scores rather than probabilities (they do not sum to 1), so they are first converted to a probability distribution with softmax:
$$
\mathrm{softmax}(\text{logits}) = [0.00657326, 0.9755587, 0.01786798]
$$
The cross-entropy loss is then:
$$
-\text{tf.reduce\_sum}(\text{labels} \cdot \text{tf.math.log}(\text{tf.nn.softmax}(\text{logits}))) = 0.82474494
$$
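The arithmetic above can be reproduced without TensorFlow. The following is a minimal sketch using only the Python standard library; the helper names `softmax` and `cross_entropy` are illustrative and are not part of any library.

```python
import math

def softmax(logits):
    """Convert raw scores to probabilities that sum to 1."""
    exps = [math.exp(v) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(labels, logits):
    """Return -sum(y_i * log(softmax(x)_i))."""
    probs = softmax(logits)
    return -sum(y * math.log(p) for y, p in zip(labels, probs))

labels = [0.0, 0.8, 0.2]
logits = [0.0, 5.0, 1.0]
probs = softmax(logits)
loss = cross_entropy(labels, logits)
print([round(p, 6) for p in probs])  # [0.006573, 0.975559, 0.017868]
print(round(loss, 4))                # 0.8247
```

Note the leading minus sign: the sum of `y * log(p)` is negative, so the loss itself is positive.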
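Using TensorFlow's built-in cross-entropy function (which already returns the positive loss, so no extra minus sign is needed):
import tensorflow as tf
labels = [0.0, 0.8, 0.2]
logits = [0.0, 5.0, 1.0]
tf.nn.softmax_cross_entropy_with_logits(labels, logits)  # 0.82474494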
In real classification tasks, labels are usually one-hot encoded, for example [0,1,0] (meaning the target class is 1). For the following data:
labels = [[0, 1, 0]]
logits = [[0, 5., 1]]
Computing with TensorFlow:
import tensorflow as tf
# Optional: allocate GPU memory on demand (does nothing on CPU-only machines)
for gpu in tf.config.list_physical_devices("GPU"):
    tf.config.experimental.set_memory_growth(gpu, True)
labels = tf.constant([[0, 1, 0]], dtype=tf.float32)
logits = tf.constant([[0, 5., 1]])
# Manual computation: -sum(y * log(softmax(x)))
origin_output = -tf.reduce_sum(labels * tf.math.log(tf.nn.softmax(logits)))
print(origin_output.numpy(),
      tf.nn.softmax_cross_entropy_with_logits(labels, logits).numpy(),
      tf.nn.sparse_softmax_cross_entropy_with_logits(tf.argmax(labels, axis=1), logits).numpy())
The result is shown below (tf.nn.sparse_softmax_cross_entropy_with_logits performs the same computation; it only takes the class index instead of the one-hot vector):
0.024744948 [0.02474492] [0.02474492]
Computing Multi-Class Cross-Entropy Loss with PyTorch
PyTorch writes the same loss in a different form:
$$
\text{loss}(x, class) = -\log\left(\frac{\exp(x[class])}{\sum_j \exp(x[j])}\right) = -x[class] + \log\left(\sum_j \exp(x[j])\right)
$$
Here $x$ is the vector of logits and $class$ is the index of the target class.
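The two forms of the formula above are algebraically identical, and the second one can be evaluated stably by subtracting the largest logit before exponentiating. The sketch below checks this with the standard library only. `loss_direct` and `loss_logsumexp` are hypothetical helper names, and the max-subtraction trick is a common practice assumed here, not something stated in the text.

```python
import math

def loss_direct(x, target):
    # First form: -log(exp(x[target]) / sum_j exp(x[j]))
    return -math.log(math.exp(x[target]) / sum(math.exp(v) for v in x))

def loss_logsumexp(x, target):
    # Second form: -x[target] + log(sum_j exp(x[j])),
    # shifted by max(x) so that exp() cannot overflow
    m = max(x)
    lse = m + math.log(sum(math.exp(v - m) for v in x))
    return -x[target] + lse

x = [0.0, 5.0, 1.0]
a = loss_direct(x, 1)
b = loss_logsumexp(x, 1)
print(round(a, 4), round(b, 4))  # 0.0247 0.0247

# Adding a constant to every logit leaves the loss unchanged; the direct
# form would overflow for these values, the shifted form does not.
big = loss_logsumexp([1000.0, 1005.0, 1001.0], 1)
print(round(big, 4))  # 0.0247
```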
For the logits $x = [0.0, 5.0, 1.0]$ above, the labels $[0.0, 0.8, 0.2]$ are largest at index 1, so the target class is 1. Note that $x$ is the network output, not the labels, so:
$$
x[class] = x[1] = 5.0
$$
The loss is:
$$
-5.0 + \log\sum_{j} e^{x[j]} = 0.0247
$$
Manual computation:
import torch
labels = torch.tensor([[0.0, 0.8, 0.2]])
logits = torch.tensor([[0.0, 5.0, 1.0]])
target = torch.argmax(labels, dim=1)  # class index: tensor([1])
loss_value = -logits[0][target[0]] + torch.log(torch.sum(torch.exp(logits), dim=1))
print(loss_value)  # tensor([0.0247])
Using the built-in loss (logits first, target class index second):
loss = torch.nn.CrossEntropyLoss()
logits = torch.tensor([[0.0, 5.0, 1.0]])
labels = torch.tensor([[0.0, 0.8, 0.2]])
target = torch.argmax(labels, dim=1)  # tensor([1])
loss_value = loss(logits, target)
print(loss_value)  # tensor(0.0247)
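torch.nn.CrossEntropyLoss takes the logits as its first argument and the target as its second. Assuming PyTorch 1.10 or later, the target may also be a tensor of class probabilities with the same shape as the logits, so the soft labels [0.0, 0.8, 0.2] can be passed directly. A minimal sketch under that version assumption:

```python
import torch

logits = torch.tensor([[0.0, 5.0, 1.0]])
soft_labels = torch.tensor([[0.0, 0.8, 0.2]])  # class probabilities

# Probability targets (assumes PyTorch >= 1.10)
soft_loss = torch.nn.CrossEntropyLoss()(logits, soft_labels)

# The same quantity by hand: -sum(y * log_softmax(x)), averaged over the batch
manual = -(soft_labels * torch.log_softmax(logits, dim=1)).sum(dim=1).mean()
print(soft_loss.item(), manual.item())
```

Both values should be close to 0.8247, the result obtained with tf.nn.softmax_cross_entropy_with_logits for the same data.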
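For the one-hot example from the TensorFlow section:
import torch
one_hot_labels = torch.tensor([[0., 1., 0.]])
labels = torch.argmax(one_hot_labels, dim=1)  # class indices: tensor([1])
# labels = torch.tensor([1])
logits = torch.tensor([[0., 5., 1.]])
loss1 = torch.nn.CrossEntropyLoss()(logits, labels)
# CrossEntropyLoss is equivalent to LogSoftmax followed by NLLLoss
loss2 = torch.nn.NLLLoss()(torch.nn.LogSoftmax(dim=-1)(logits), labels)
print(loss1, loss2)  # tensor(0.0247) tensor(0.0247)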
Full verification:
import torch
from torch import nn

def logsoftmax(data):
    # Log of the softmax over the class dimension
    return torch.log(torch.softmax(data, dim=1))

def nlloss(data, labels):
    # Negative log-likelihood: mean of -data[i][labels[i]] over the batch
    res = 0
    for index, label_num in enumerate(labels):
        res += data[index][label_num]
    return -res / len(data)

def cross_entropy(data, labels):
    # Cross-entropy = LogSoftmax followed by NLLLoss
    loss_value = nlloss(logsoftmax(data), labels)
    return loss_value

class_num = 5
data = torch.randn(3, class_num, dtype=torch.float)
labels = torch.randint(0, class_num, size=(3,))
# Built-in LogSoftmax vs. the manual version
t_logsoftmax = nn.LogSoftmax(dim=-1)(data)
m_logsoftmax = logsoftmax(data)
print(t_logsoftmax, m_logsoftmax)
# Built-in NLLLoss vs. the manual version
t_nlloss = nn.NLLLoss()(t_logsoftmax, labels)
m_nlloss = nlloss(m_logsoftmax, labels)
print(t_nlloss)
print(m_nlloss)
# Built-in CrossEntropyLoss vs. the manual version
loss = nn.CrossEntropyLoss()
t_loss = loss(data, labels)
m_loss = cross_entropy(data, labels)
print("Cross-entropy loss: {:.4f} (PyTorch) {:.4f} (mine)".format(t_loss, m_loss))
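The Python loop in nlloss can also be written with tensor indexing. The sketch below is one possible vectorized variant (`nlloss_vectorized` is a hypothetical name, not from the code above), compared against nn.NLLLoss on seeded random data:

```python
import torch
from torch import nn

def nlloss_vectorized(log_probs, labels):
    # Pick log_probs[i, labels[i]] for every row i, then average and negate
    rows = torch.arange(log_probs.shape[0])
    return -log_probs[rows, labels].mean()

torch.manual_seed(0)
data = torch.randn(3, 5)
labels = torch.randint(0, 5, size=(3,))
log_probs = torch.log_softmax(data, dim=1)

v_loss = nlloss_vectorized(log_probs, labels)
t_loss = nn.NLLLoss()(log_probs, labels)
print(torch.allclose(v_loss, t_loss))  # True
```

Indexing with a row tensor and a label tensor selects one entry per sample, which avoids the per-sample Python loop on larger batches.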