CrossEntropyLoss是分类loss的一种,指交叉熵损失函数。下面是基于pytorch的ceLoss的简单实现。
import torch
import torch.nn as nn
import torch.nn.functional as F
#二个分布的差距
#logits shape:[bs,nc]
batchsize = 2
num_class = 4
logits = torch.randn(batchsize,num_class) #未归一化
target = torch.randint(num_class,size=(batchsize,)) #传入的是整型索引
target_logits = torch.randn(batchsize,num_class)
## 调用Cross Entropy loss
ce_loss_fn = nn.CrossEntropyLoss()
ce_loss = ce_loss_fn(logits,target)
ce_loss2 = ce_loss_fn(logits,target_logits)