import torch
import torch.nn as nn
# Three raw (pre-sigmoid) scores and their binary labels, each shaped (3, 1).
logits = torch.tensor([0.5, -1.0, 2.0], dtype=torch.float32).unsqueeze(1)
targets = torch.tensor([1.0, 0.0, 1.0], dtype=torch.float32).unsqueeze(1)

# BCEWithLogitsLoss applies the sigmoid and the binary cross-entropy in one
# step (numerically more stable than nn.Sigmoid followed by nn.BCELoss).
# With reduction="sum" the per-element losses are added rather than averaged.
loss_fn = nn.BCEWithLogitsLoss(reduction="sum")

# Per element: label 1 -> log(1 + exp(-x)), label 0 -> log(1 + exp(x)):
#   log(1 + e^-0.5) ~ 0.4741, log(1 + e^-1) ~ 0.3133, log(1 + e^-2) ~ 0.1269
loss = loss_fn(input=logits, target=targets)
print(float(loss))  # expected output: 0.9142667055130005