引言
本文旨在对pytorch中常用于分类问题的损失函数BinaryCrossEntropy(), CrossEntropy()用法进行一个简要的介绍。常见的文章主要是对这些损失函数的原理进行了数学推导,而本文主要介绍了其输入输出的shape和格式要求,作为一个工具存在。
损失函数
本文涉及到的损失函数有BCELoss()、BCEWithLogitsLoss()、NLLLOSS()、CrossEntropyLoss(),前两者是二分类问题常用的损失函数,后两者是多分类问题常用的损失函数。列出格式表如下:
输入格式 | label的dtype | 是否为独热向量 | 网络输出是否需要激活 | |
---|---|---|---|---|
BCELoss() | (pred:[*],label:[*],二者相同即可) | torch.float32 | 否 | 是 |
BCEWithLogitsLoss() | (pred:[*],label:[*],二者相同即可) | torch.float32 | 否 | 否 |
NLLLOSS() | (pred:[N,C],label:[N,]) | torch.int64 | 否 | 是 |
CrossEntropyLoss() | (pred:[N,C],label:[N,])或 (pred:[N,C],label:[N,C]) | torch.int64/torch.float32,torch.float64 | 否 | 否 |
其中CrossEntropyLoss之所以会有label为[N,C]形状却不并不为onehot向量,这是因为这里的label描述的是一个样本属于多个类别的情况,可以认为是属于每一种类别的可能性,也可以认为是软化的onehot向量。
计算方式
(默认在batch上采用平均):
BCELoss()
l o s s = − 1 N ∑ i N [ y i ⋅ l o g ( p i ) + ( 1 − y i ) ⋅ l o g ( 1 − p i ) ] loss=-\frac{1}{N}\sum_i^{N}[y_i\cdot log(p_i)+ (1-y_i)\cdot log(1-p_i)] loss=−N1∑iN[yi⋅log(pi)+(1−yi)⋅log(1−pi)],其中 y i y_i yi为实际标签, p i p_i pi为网路预测其属于正样本的输出值(非概率)。
BCEWithLogitsLoss()
l o s s = − 1 N ∑ i N ( 1 − y i ) ⋅ l o g ( 1 − p i ) loss=-\frac{1}{N}\sum_i^{N}(1-y_i)\cdot log(1-p_i) loss=−N1∑iN(1−yi)⋅log(1−pi),其中 y i y_i yi为实际标签, p i p_i pi为网路预测其属于正样本的概率。
NLL loss
l o s s = − 1 N ∑ i N ∑ j C I ( y i = c ) p i c loss=-\frac{1}{N}\sum_i^{N} \sum_j^CI (y_{i}=c)p_{ic} loss=−N1∑iN∑jCI(yi=c)pic,其中 I I I为指示函数,当第 i i i个样本的标签 y i y_{i} yi与当前类别c相同时取1,否则取0; p i c p_{ic} pic为网络输出的第 i i i个样本属于第 c c c类的概率。
CrossEntropyLoss()
l o s s = − 1 N ∑ i N ∑ j c l o g e x p ( p c ) ∑ j c e x p ( p c ) I ( y i = c ) loss=-\frac{1}{N}\sum_i^{N}\sum_j^clog\frac{exp(p_c)}{\sum_j^cexp(p_c)}I (y_{i}=c) loss=−N1∑iN∑jclog∑jcexp(pc)exp(pc)I(yi=c),其中 I I I为指示函数,当第 i i i个样本的标签 y i y_{i} yi与当前类别c相同时取1,否则取0; p i c p_{ic} pic为网络输出的第 i i i个样本属于第 c c c类的输出值(非概率)。