F.cross_entropy和F.binary_cross_entropy_with_logits

F.cross_entropy 函数对应的类是torch.nn.CrossEntropyLoss,在使用时会自动添加logsoftmax然后计算loss(其实就是nn.LogSoftmax() 和nn.NLLLoss() 类的融合)
该函数用于计算多分类问题的交叉熵loss

函数形式:
在这里插入图片描述
这种形式更好理解

在这里插入图片描述

C为class的数目
input 1维情况x[N, C] n维度情况[N, c, d1, d2, d3…]
target 1维度情况[N] n维度情况[N, C, d1, d2, d3] 这的数值范围必须是0-C-1
weight [C] 代表每个类别的权重
注意这里的target必须是long类型

这里的weight是给每一个类别一个weight
一般情况(default)下都会将所有元素的loss进行平均,所以最后loss的公式时为
在这里插入图片描述
Wn表示第n个样本的对应gt类别的权重, Pn是模型预测的对应gt类别的概率。分子上将所有样本的loss相加,然后除以所有样本权重的和。

import numpy as np
import torch
from torch.nn import functional as F
x = np.array([[1, 2],
              [1, 2],
              [1, 1]]).astype(np.float32)

y = np.array([1, 1, 0])
weight = np.array([2, 1])
x = torch.from_numpy(x)
print(F.softmax(x, dim=1))
y = torch.from_numpy(y).long()
weight = torch.from_numpy(weight).float()
loss = F.cross_entropy(x, y, weight=weight)
print(loss)

手动计算
在这里插入图片描述


F.binary_cross_entropy_with_logits()对应的类是torch.nn.BCEWithLogitsLoss,在使用时会自动添加sigmoid,然后计算loss。(其实就是nn.sigmoid和nn.BCELoss的合体)

该函数用于计算多分类问题的交叉熵loss
这里
input [N, *]
target[N, *]
这里的target必须是float类型

这里直接那nn.BCELoss举例子,注意使用nn.BCELoss的时候必须保证输入的是经过sigmoid激活的概率值

这里的weight是给每一个点一个weight

import torch
from torch import nn
def binary_cross_entropyloss(prob, target, weight=None):
    loss = -weight * (target * torch.log(prob) + (1 - target) * (torch.log(1 - prob)))
    loss = torch.sum(loss) / torch.numel(target)
    return loss

label = torch.tensor([
    [1., 0],
    [1., 0],

])
predict = torch.tensor([
    [0.1, 0.3],
    [0.2, 0.8]
])

weight1 = torch.tensor([
    [1., 2],
    [1., 1.],
])

loss1 = nn.BCELoss(weight=weight1)
l1 = loss1(predict, label)
loss = binary_cross_entropyloss(predict, label, weight=weight1)
print(l1, loss)

在这里插入图片描述

  • 19
    点赞
  • 42
    收藏
    觉得还不错? 一键收藏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值