在使用Pytorch时经常碰见这些函数cross_entropy,CrossEntropyLoss, log_softmax, softmax。看得我头大,所以整理本文以备日后查阅。
首先要知道上面提到的这些函数一部分是来自于torch.nn,而另一部分则来自于torch.nn.functional(常缩写为F)。二者函数的区别可参见 知乎:torch.nn和funtional函数区别是什么?
下面是对与cross entropy有关的函数做的总结:
torch.nn | torch.nn.functional (F) |
---|---|
CrossEntropyLoss | cross_entropy |
LogSoftmax | log_softmax |
NLLLoss | nll_loss |
下面将主要介绍torch.nn.functional中的函数为主,torch.nn中对应的函数其实就是对F里的函数进行包装以便管理变量等操作。
在介绍cross_entropy之前先介绍两个基本函数:
1|0log_softmax
这个很好理解,其实就是log和softmax合并在一起执行。
2|0nll_loss
该函数的全程是negative log likelihood loss,函数表达式为
f(x,class)=−x[class]f(x,class)=−x[class]
例如假设x=[1,2,3],class=2x=[1,2,3],class=2,那额f(x,class)=−x[2]=−3f(x,class)=−x[2]=−3
3|0cross_entropy
交叉熵的计算公式为:
cross_entropy=−∑k=1N(pk∗logqk)cross_entropy=−∑k=1N(pk∗logqk)
其中pp表示真实值,在这个公式中是one-hot形式;qq是预测值,在这里假设已经是经过softmax后的结果了。
仔细观察可以知道,因为pp的元素不是0就是1,而且又是乘法,所以很自然地我们如果知道1所对应的index,那么就不用做其他无意义的运算了。所以在pytorch代码中target不是以one-hot形式表示的,而是直接用scalar表示。所以交叉熵的公式(m表示真实类别)可变形为:
cross_entropy=−∑k=1N(pk∗logqk)=−logqmcross_entropy=−∑k=1N(pk∗logqk)=−logqm
仔细看看,是不是就是等同于log_softmax和nll_loss两个步骤。
所以Pytorch中的F.cross_entropy会自动调用上面介绍的log_softmax和nll_loss来计算交叉熵,其计算方式如下:
loss(x,class)=−log(exp(x[class])∑jexp(x[j]))loss(x,class)=−log(exp(x[class])∑jexp(x[j]))
代码示例
>>> input = torch.randn(3, 5, requires_grad=True)
>>> target = torch.randint(5, (3,), dtype=torch.int64)
>>> loss = F.cross_entropy(input, target)
>>> loss.backward()