Reading the source of the `weight` argument in PyTorch's `nn.BCELoss`
The most common problem in classification is data imbalance; in binary classification it usually appears as an imbalance between positive and negative samples. The best remedy is to rebuild a balanced dataset, but when the data are inherently imbalanced (anomaly detection, for example), two approaches are commonly used:
- Sampling: undersample the majority class and oversample the minority class
- Weighting: adjust the per-sample weights so the network is forced to pay attention to the minority class
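As a minimal sketch of the weighting approach (the batch values and the 3x weight below are illustrative, not from this post): `nn.BCELoss` accepts a per-element `weight` tensor, so a class-level weight can be built from the targets of each batch.

```python
import torch
import torch.nn as nn

# Toy batch: one positive (minority) sample and three negatives.
preds = torch.tensor([0.9, 0.2, 0.3, 0.1])    # probabilities after sigmoid
targets = torch.tensor([1.0, 0.0, 0.0, 0.0])

# Give the lone positive 3x weight so both classes contribute
# roughly equally to the total loss (3 negatives vs. 1 positive).
weights = targets * 3.0 + (1.0 - targets)      # -> [3., 1., 1., 1.]

criterion = nn.BCELoss(weight=weights)         # default reduction='mean'
loss = criterion(preds, targets)
print(loss.item())
```

Because `weight` must match the batch element-wise, it has to be rebuilt (or indexed from the targets, as above) for every batch. For logit inputs, `nn.BCEWithLogitsLoss` offers a `pos_weight` argument that weights the positive class directly.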
nn.BCELoss
I read many online explanations of the `weight` parameter of `BCELoss`, but none of them felt clear, so I went straight to the source code to see how the logic behind `weight` actually runs.
```python
class BCELoss(_WeightedLoss):
    r"""
    Args:
        weight (Tensor, optional): a manual rescaling weight given to the loss
            of each batch element. If given, has to be a Tensor of size `nbatch`.
        size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
            the losses are averaged over each loss element in the batch. Note that for
            some losses, there are multiple elements per sample. If the field :attr:`size_average`
            is set to ``False``, the losses are instead summed for each minibatch. Ignored
            when :attr:`reduce` is ``False``. Default: ``True``
```