loss盘点: BCE loss —— binary_cross_entropy_with_logits

我的 torch 版本: 1.8.1+cu111
我的 paddle 版本: 2.4.1

torch API位置

torch.nn.functional.binary_cross_entropy_with_logits

paddle API位置

paddle.nn.functional.binary_cross_entropy_with_logits

二者除了 Deprecated 参数外,在大部分计算上基本都是对齐的

1. 计算方式

logit 是模型的输出,通过sigmoid激活函数 ( σ \sigma σ) 之后便可以转化为概率

l o s s loss loss 是这样计算的:
O u t = − L a b e l s ∗ log ⁡ ( σ ( L o g i t ) ) − ( 1 − L a b e l s ) ∗ log ⁡ ( 1 − σ ( L o g i t ) ) Out = -Labels * \log(\sigma(Logit)) - (1 - Labels) * \log(1 - \sigma(Logit)) Out=Labelslog(σ(Logit))(1Labels)log(1σ(Logit))
其实也就是,交叉熵 l o s s loss loss 的最基本公式:
O u t = − Y ∗ l o g ( y p r e d ) − ( 1 − Y ) ∗ l o g ( 1 − y p r e d ) Out = -Y * log(y_{pred}) - (1 - Y) * log(1 - y_{pred}) Out=Ylog(ypred)(1Y)log(1ypred)

σ ( L o g i t ) = 1 1 + e − L o g i t \sigma(Logit) = \frac{1}{1 + e^{-Logit}} σ(Logit)=1+eLogit1 带入可以简化计算,则:
O u t = L o g i t − L o g i t ∗ L a b e l s + log ⁡ ( 1 + e − L o g i t ) Out = Logit - Logit * Labels + \log(1 + e^{-Logit}) Out=LogitLogitLabels+log(1+eLogit)

文档上这样说:

该 OP 结合了 sigmoid 操作和 BCELoss 操作。同时,我们也可以认为该 OP 是sigmoid_cross_entrop_with_logits 和一些 reduce 操作的组合。

2. 实验代码

# -*- coding: utf-8 -*-
"""
Created on Wed Jan  4 22:36:50 2023

@author: zihao
"""

import numpy as np
import torch
import paddle


# ----------- numpy 参数 -----------
np.random.seed(1107)

# 假设 bs=4, 7种(多分类)
np_logit = np.random.rand(4, 7).astype("float32") 
np_target = np.random.randint(2, size=(4, 7)).astype("float32")

# 给每个类加权重
np_pos_weight = np.random.randint(2, 4, size=(7,)).astype("float32")

# 给每个 batch 的元素 加权重
np_weight = np.random.randint(2, 4, size=(7,)).astype("float32")


# ----------- torch -----------
t_logit = torch.tensor(np_logit)
t_target = torch.tensor(np_target)
t_pos_weight = torch.tensor(np_pos_weight)
t_weight = torch.tensor(np_weight)
t_out = torch.nn.functional.binary_cross_entropy_with_logits(t_logit, t_target,
                                                             weight=t_weight,
                                                             pos_weight=t_pos_weight,
                                                             reduction='none')

# torch 手动计算
t_out_hand = t_logit - t_logit * t_target + torch.log(1 + torch.exp(-t_logit))
t_pos_weight = t_target * t_pos_weight + (1 - t_target)
t_out_hand = t_out_hand * t_pos_weight 
t_out_hand = t_out_hand * t_weight 


# ----------- paddle -----------
p_logit = paddle.to_tensor(np_logit)
p_target = paddle.to_tensor(np_target)
p_pos_weight = paddle.to_tensor(np_pos_weight)
p_weight = paddle.to_tensor(np_weight)
p_out = paddle.nn.functional.binary_cross_entropy_with_logits(p_logit, p_target, 
                                                              weight=p_weight,
                                                              pos_weight=p_pos_weight,
                                                              reduction='none')

# paddle  手动计算
p_out_hand = p_logit - p_logit * p_target + paddle.log(1 + paddle.exp(-p_logit))
p_pos_weight = p_target * p_pos_weight + (1 - p_target)
p_out_hand = p_out_hand * p_pos_weight 
p_out_hand = p_out_hand * p_weight 

在以上代码中,t_out 和 t_out_hand 近乎相等,p_out 和 p_out_hand 近乎相等。前者是调用API计算的,后者是根据公式手动计算的

3. 稍稍看下源码

在 Paddle 源码新动态图部分是这样计算的:

    if in_dygraph_mode():
        one = _C_ops.full(
            [1],
            float(1.0),
            core.VarDesc.VarType.FP32,
            _current_expected_place(),
        )
		
		# 此处按照公式进行计算
        out = _C_ops.sigmoid_cross_entropy_with_logits(
            logit, label, False, -100
        )
		
		# 给每个正例乘以对应的权重 pos_weight 
        if pos_weight is not None:
            log_weight = _C_ops.add(
                _C_ops.multiply(label, _C_ops.subtract(pos_weight, one)), one
            )
            out = _C_ops.multiply(out, log_weight)
		
		# 给每个 batch 乘以对应权重
        if weight is not None:
            out = _C_ops.multiply(out, weight)
		
		# 做 reduce 操作
        if reduction == "sum":
            return _C_ops.sum(out, [], None, False)
        elif reduction == "mean":
            return _C_ops.mean_all(out)
        else:
            return out

关于 pos_weight 的计算,诸位需要稍微认真看一下

我是这样计算的:

p_pos_weight = p_target * p_pos_weight + (1 - p_target)

可以这样简化一下:

p_pos_weight = p_pos_weight * p_target - one * p_target + one
             = (p_pos_weight - one) * p_target + one
             = p_target * (p_pos_weight - one) + one

也就是源码中这样的计算

log_weight = _C_ops.add(
                _C_ops.multiply(label, _C_ops.subtract(pos_weight, one)), one
            )
  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值