目录
MSE与交叉熵损失:
- 均方差损失函数(MSE): 求一个batch中n个样本的n个输出与期望输出的差的平方的平均值, 常用与回归问题
- 交叉熵损失: 用来评估当前训练得到的概率分布于真实分布的差异情况。 它刻画的是实际输出(概率)与期望输出(概率)的距离,也就是交叉熵的值越小,两个概率分布就越接近。
- 在分类问题中不用MSE,主要有两个原因:① 参数更新缓慢 ②非凸优化问题;
1.BCE Loss
使用BCELoss()函数之前需要加上sigmoid函数:
import torch.nn as nn
sigmoid = nn.Sigmoid()
loss = nn.BCELoss()
final_loss = loss(output, label)
2.top k BCE Loss:
在所有类别中找出前k个error最高的数据,然后拿出来进行求bce loss
def BCE_loss(results, labels, topk=10):
error = torch.abs(labels - torch.sigmoid(results))#one_hot_target
error = error.topk(topk, 1, True, True)[0].contiguous()
target_error = torch.zeros_like(error).float()
error_loss = nn.BCELoss(reduce='mean')(error, target_error)
3.BCE With LogitsLoss
(自带sigmoid)
下面这个代码是输出多个类别,只有一个类别是正例子,对所有类别×相应的权重然后平均或者sum.这个方法可以用于多分类;
loss_f_none_w = nn.BCEWithLogitsLoss(weight=weights, reduction='none')
loss_f_sum = nn.BCEWithLogitsLoss(weight=weights, reduction='sum')
loss_f_mean = nn.BCEWithLogitsLoss(weight=weights, reduction='mean')
# forward
loss_none_w = loss_f_none_w(inputs, target_bce)
loss_sum = loss_f_sum(inputs, target_bce)