BCELoss,BCEWithLogitsLoss和CrossEntropyLoss

目录

二分类

1. BCELoss

2. BCEWithLogitsLoss

多分类

1. CrossEntropyLoss

 举例


二分类

两个损失:BCELoss,BCEWithLogitsLoss

1. BCELoss

输入:([B,C], [B,C]),代表(prediction,target)的维度,其中,B是Batchsize,C为样本的class,即样本的类别数。

输出:一个标量

等价于:BCELoss + sigmoid

import torch
from torch import nn

input = torch.randn(3) # (3,1) 随机生成一个输入,没有被sigmoid。
print(input)
print(input.shape)
target=torch.Tensor([0., 1., 1.])
loss1=nn.BCELoss()
print("BCELoss:",loss1(torch.sigmoid(input), target))#需要sigmod


输出:
BCELoss: tensor(1.0053)


2. BCEWithLogitsLoss

输入:([B,C], [B,C]),输出:一个标量

import torch
from torch import nn

input = torch.randn(3) # (3,1) 随机生成一个输入,没有被sigmoid。
print(input)
print(input.shape)
target=torch.Tensor([0., 1., 1.])
loss2=nn.BCEWithLogitsLoss()
print("BCEWithLogitsLoss:",loss2(input,target))#不需要sigmoid


输出:
BCEWithLogitsLoss: tensor(1.0053)

多分类

1. CrossEntropyLoss

输入:([B,C], [B]) 输出:一个标量(这个minibatch的mean/sum的loss)

nn.CrossEntropyLoss计算过程: 
input: logits(未经过softmax的模型的"输出”)

  •  softmax(input)
  • -log(softmax(input))
  • 用target做选择提取(关于logsoftmax)· mean

等价于:nn.CrossEntropyLoss = nn.NLLLoss(nn.LogSoftmax)
 

import torch
from torch import nn

loss2 = nn.CrossEntropyLoss(reduction="none")
target2 = torch.tensor([0, 1, 2])
predict2 = torch.tensor([[0.9, 0.2, 0.8], [0.5, 0.2, 0.4], [0.4, 0.2, 0.9]])
print(predict2.shape) # torch.Size([3, 3])
print(target2.shape) # torch.Size([3])
print(loss2(predict2, target2))

# #结果计算为:
# tensor([0.8761, 1.2729, 0.7434])

 举例

1. BCEWithLogitsLoss计算ACC和Loss:

参考:https://github.com/Loche2/IMDB_RNN/blob/master/training.py

criterion = nn.BCEWithLogitsLoss()
# 计算准确率
def binary_accuracy(predicts, y):
    rounded_predicts = torch.round(torch.sigmoid(predicts))
    correct = (rounded_predicts == y).float()
    accuracy = correct.sum() / len(correct)
    return accuracy


# 训练
def train(model, iterator, optimizer, criterion):
    model.train()
    epoch_loss = 0
    epoch_accuracy = 0
    for batch in tqdm(iterator, desc=f'Epoch [{epoch + 1}/{EPOCHS}]', delay=0.1):
        optimizer.zero_grad()
        predictions = model(batch.text[0]).squeeze(1)
        loss = criterion(predictions, batch.label)
        accuracy = binary_accuracy(predictions, batch.label)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        epoch_accuracy += accuracy.item()
    return epoch_loss / len(iterator), epoch_accuracy / len(iterator)

2. 计算ACC和Loss

# 截取情感分析部分代码 
    criterion = nn.CrossEntropyLoss()
    total_loss = 0.0
    correct_predictions = 0
    total_predictions = 0
   for batch in train_loader:
        input_ids = batch['input_ids'].to(device)
        labels = batch['label'].to(device)

        optimizer.zero_grad()

        logits = model(input_ids)
        loss_sentiment = criterion(logits, labels.long())
       
        loss_sentiment.backward()
        optimizer.step()

        total_loss += loss_sentiment.item()

        # get sentiment accuracy
        predicted_labels = torch.argmax(logits, dim=1)
        correct_predictions += torch.sum(predicted_labels == labels).item()
        total_predictions += labels.size(0)

    accuracy = correct_predictions / total_predictions
    loss = total_loss / len(train_loader)

也可以直接看github上别人写的例子:https://github.com/songyouwei/ABSA-PyTorch/blob/master/train.py

参考:

深刻剖析与实战BCELoss详解(主)和BCEWithLogitsLoss(次)以及与普通CrossEntropyLoss的区别(次)-CSDN博客

另外提出一个问题:

二分类必须用BCEWithLogitsLoss吗,也可以用CrossEntropyLoss吧?

(1)如果用CrossEntropyLoss的话,只要让网络的fc层为nn.Linear(hidden_size, 2)就行,这样就和多分类一样算。另外CrossEntropyLoss里面包含了softmax,所以在计算loss的时候也不需要过softmax再算loss.

(2)如果用BCEWithLogitsLoss的话,就按照上面举例中BCEWithLogitsLoss计算Loss,只是如上面代码可是,再计算Acc的时候将predict使用sigimoid缩放到0,1来计算预测正确的个数

注:仅供学习记录,理解或者学习有误请与我联系

参考问题:二分类问题,应该选择sigmoid还是softmax? - 知乎 

  • 7
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值