1 定义对称交叉熵损失函数
import torch
import torch.nn.functional as F
import numpy as np
# NOTE(review): `eps` appears unused in this visible chunk — SCELoss below
# uses the literal 1e-7 directly. Confirm no later code needs it before removing.
eps = 1e-7
class SCELoss(torch.nn.Module):
    """Symmetric Cross-Entropy (SCE) loss: ``alpha * CE + beta * RCE``.

    Combines standard cross-entropy (CE) with reverse cross-entropy (RCE)
    for robustness against noisy labels.

    Args:
        alpha: weight of the standard cross-entropy term.
        beta: weight of the reverse cross-entropy term.
        num_classes: number of classes used to one-hot encode the labels.
    """

    def __init__(self, alpha, beta, num_classes=10):
        super(SCELoss, self).__init__()
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.alpha = alpha
        self.beta = beta
        self.num_classes = num_classes
        self.cross_entropy = torch.nn.CrossEntropyLoss()

    def forward(self, pred, labels):
        """Compute the SCE loss.

        Args:
            pred: raw logits, shape ``(batch, num_classes)``.
            labels: integer class indices, shape ``(batch,)``.

        Returns:
            Scalar tensor ``alpha * CE + beta * mean(RCE)``.
        """
        # CE term; labels cast to long so float/int32 label tensors also work.
        ce = self.cross_entropy(pred, labels.long())

        # RCE term: softmax turns logits into a probability distribution
        # (rows sum to 1); clamp keeps every probability strictly positive.
        pred = F.softmax(pred, dim=1)
        pred = torch.clamp(pred, min=1e-7, max=1.0)
        # One-hot labels clamped to 1e-4 so log(0) becomes the finite
        # constant log(1e-4) (the "A" constant of the SCE formulation).
        # Bug fix: place the one-hot tensor on pred's device rather than a
        # globally detected device, so the loss also works when the model
        # lives on a specific GPU (e.g. 'cuda:1') or stays on CPU while
        # CUDA happens to be available.
        label_one_hot = F.one_hot(labels.to(torch.int64), self.num_classes).float().to(pred.device)
        label_one_hot = torch.clamp(label_one_hot, min=1e-4, max=1.0)
        rce = -torch.sum(pred * torch.log(label_one_hot), dim=1)

        # Weighted combination of both terms; RCE is averaged over the batch.
        return self.alpha * ce + self.beta * rce.mean()
2 调用对称交叉熵损失函数
# Usage example: `loss_function`, `outputs`, and `labels` are presumably
# defined by the surrounding training script — not visible in this chunk.
if loss_function =='SCE':
    # alpha=0.1 down-weights the CE term; beta=1.0 keeps the RCE term at
    # full strength. num_classes must match the dataset (10 here).
    criterion = SCELoss(alpha=0.1, beta=1.0, num_classes=10)
    loss = criterion(outputs, labels)
3 参考文章
Robust Federated Learning with Noisy and Heterogeneous Clients (CVPR 2022)