前言:
一直想写损失函数的技术总结,但网上已经有诸多关于损失函数综述的文章或博客,考虑到这点就一直拖着没写,直到有一天,我将一个二分类项目修改为多分类,简简单单地修改了损失函数,结果一直有问题,后来才发现是不同函数的标签的设置方式并不相同。
为了避免读者也出现这样的问题,本文中会给出每个损失函数的pytorch使用示例,这也是本文与其它相关综述文章或博客的区别所在。希望读者在阅读本文时,重点关注一下每个损失函数的使用示例中的target的设置问题。
本文对损失函数的类别和应用场景,常见的损失函数,常见损失函数的表达式,特性,应用场景和使用示例作了详细的总结。
主要涉及到L1 loss、L2 loss、Negative Log-Likelihood loss、Cross-Entropy loss、Hinge Embedding loss、Margin Ranking Loss、Triplet Margin loss、KL Divergence.
损失函数分类与应用场景
损失函数可以分为三类:回归损失函数(Regression loss)、分类损失函数(Classification loss)和排序损失函数(Ranking loss)。
应用场景:
回归损失:用于预测连续的值。如预测房价、年龄等。
分类损失:用于预测离散的值。如图像分类,语义分割等。
排序损失:用于预测输入数据之间的相对距离。如行人重识别。
L1 loss
也称Mean Absolute Error,简称MAE,计算实际值和预测值之间的绝对差之和的平均值。
表达式如下:
Loss( pred , y ) = | y - pred |
y表示标签,pred表示预测值。
应用场合:回归问题。
根据损失函数的表达式很容易了解它的特性:当目标变量的分布具有异常值时,即与平均值相差很大的值,它被认为对异常值具有很好的鲁棒行。
使用示例:
input = torch.randn(3, 5, requires_grad=True)
target = torch.randn(3, 5)
mae_loss = torch.nn.L1Loss()
output = mae_loss(input, target)
L2 loss
也称为Mean Squared Error,简称MSE,计算实际值和预测值之间的平方差的平均值。
表达式如下:
应用场合:对大部分回归问题,pytorch默认使用L2,即MSE。
使用平方意味着当预测值离目标值更远时在平方后具有更大的惩罚,预测值离目标值更近时在平方后惩罚更小,因此,当异常值与样本平均值相差格外大时,模型会因为惩罚更大而开始偏离,相比之下,L1对异常值的鲁棒性更好。
使用示例:
input = torch