pytorch各种交叉熵函数的汇总具体使用

一、引言

  最近被pytorch的各种交叉熵损失弄迷糊了,现在我所知道的交叉熵损失有:

torch.nn.CrossEntropyLoss()

torch.nn.BCELoss()

torch.nn.BCEWithLogitsLoss()

torch.nn.functional.cross_entropy()

torch.nn.functional.binary_cross_entropy()

torch.nn.functional.binary_cross_entropy_with_logits()

torch.nn.NLLLoss()

最麻烦的就是这些函数都是交叉熵损失,但是输入网络predict特征需不需要做softmax,sigmod之类的计算,是不是只输入网络输出的特征即可?输入label需不需要onehot编码,一旦不对,损失就不是想要的损失,效果自然也会不好。下面来分析下:

二、具体分析

总结:损失函数名字中带了with_logits. 而这里的logits指的是,该损失函数已经内部自带了计算logit的操作,无需在传入给这个loss函数之前手动使用sigmoid/softmax将之前网络的输入映射到[0,1]之间

1、torch.nn.CrossEntropyLoss()

  输入不需要做softmax等操作,函数内部会做softmax操作,label不需要进行oneshot编码,如下代码所示:内部计算步骤:进行softmax操作,然后吧label进行oneshot编码,再进行交叉熵计算,求值返回。

输入维度(batch_size, feature_dim)

输出维度  (batch_size, 1)

X_input = torch.tensor[ [2.8883, 0.1760, 1.0774],

          [1.1216, -0.0562, 0.0660],

          [-1.3939, -0.0967, 0.5853]]

y_target = torch.tensor([1,2,0])

loss_func = nn.CrossEntropyLoss()

loss = loss_func(X_input, y_target)

2、torch.nn.BCELoss()

  输入需要做sigmod计算,bce是做的是sigmod计算。label需要编码,BCE损失即可以用作多标签分类,也可以用做多分类。如第三个代码就是多标签分类,第二个代码也就是下面这个代码是多分类。

import torch
import torch.nn as nn

m = nn.Sigmoid()

loss = nn.BCELoss(size_average=False, reduce=False)
input = torch.randn(3, requires_grad=True)
target = torch.empty(3).random_(2)
lossinput = m(input)
output = loss(lossinput, target)

3、torch.nn.BCEWithLogitsLoss()

    带有Logits的函数,需要oneshot编码,输入predict是不需要sigmod计算的。

import torch
import numpy as np

pred = np.array([[-0.4089, -1.2471, 0.5907],
                [-0.4897, -0.8267, -0.7349],
                [0.5241, -0.1246, -0.4751]])
label = np.array([[0, 1, 1],
                  [0, 0, 1],
                  [1, 0, 1]])

pred = torch.from_numpy(pred).float()
label = torch.from_numpy(label).float()

## 通过BCEWithLogitsLoss直接计算输入值(pick)
crition1 = torch.nn.BCEWithLogitsLoss()
loss1 = crition1(pred, label)
print(loss1)

4、torch.nn.functional.cross_entropy()

不需要进行softmax计算,不需要进行oneshot编码,与第一个等价

 

import torch
import torch.nn.functional as F
input = torch.randn(3, 5, requires_grad=True) 
target = torch.randint(5, (3,), dtype=torch.int64)输出:tensor([0, 3, 3])
loss = F.cross_entropy(input, target)
loss.backward()

5、torch.nn.functional.binary_cross_entropy() 

为二值交叉熵损失,输入需要进行sigmod计算,需要oneshot编码,与第二个等价

 

6、torch.nn.functional.binary_cross_entropy_with_logits()

和第三个等价,带有Logits的函数,需要oneshot编码,输入predict是不需要sigmod计算的。

input = torch.randn(3, requires_grad=True)
target = torch.empty(3).random_(2)  输出:tensor([0., 1., 0.])

loss = F.binary_cross_entropy_with_logits(input, target)
loss.backward()

 7nn.NLLLoss()

 

m = nn.LogSoftmax(dim=1)
loss = nn.NLLLoss()
# input is of size N x C = 3 x 5
input = torch.randn(3, 5, requires_grad=True)
# each element in target has to have 0 <= value < C
target = torch.tensor([1, 0, 4])
output = loss(m(input), target)
output.backward()

 

  • 10
    点赞
  • 26
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值