用 PyTorch 自己动手实现交叉熵损失函数
PyTorch 实现交叉熵损失函数
import torch
import torch.nn as nn
import torch.nn.functional as F
class Loss(torch.nn.Module):
def __init__(self, reduction='mean'):
super(Loss, self).__init__()
self.reduction = reduction
def forward(self, logits, target): # [bs,num_class] CE=q*-log(p), q*log(1-p),p=softmax(logits)
target = target.reshape(logits.shape