pytorch多标签分类类别不平衡损失函数

pytorch多标签分类类别不平衡损失函数

focal loss 多标签分类版

def criterion(y_pred, y_true, weight=None, alpha=0.25, gamma=2):
    sigmoid_p = nn.Sigmoid(y_pred)
    zeros = torch.zeros_like(sigmoid_p)
    pos_p_sub = torch.where(y_true > zeros,y_true - sigmoid_p,zeros)
    neg_p_sub = torch.where(y_true > zeros,zeros,sigmoid_p)
    per_entry_cross_ent = -alpha * (pos_p_sub ** gamma) * torch.log(torch.clamp(sigmoid_p,1e-8,1.0))-(1-alpha)*(neg_p_sub ** gamma)*torch.log(torch.clamp(1.0-sigmoid_p,1e-8,1.0))
    return per_entry_cross_ent.sum()

softmax应用于多标签分类

https://mp.weixin.qq.com/s/Ii2sxJUGNvX4CnmtVmbFwA

def criterion2(y_pred,y_true):
    y_pred = (1 - 2*y_true)*y_pred
    y_pred_neg = y_pred - y_true * 1e12
    y_pred_pos = y_pred - (1 - y_true) * 1e12
    zeros = torch.zeros_like(y_pred[...,:1])
    y_pred_neg = torch.cat((y_pred_neg,zeros),dim=-1)
    y_pred_pos = torch.cat((y_pred_pos,zeros),dim=-1)
    neg_loss = torch.logsumexp(y_pred_neg,dim=-1)
    pos_loss = torch.logsumexp(y_pred_pos,dim=-1)
    return torch.mean(neg_loss + pos_loss)
  • 7
    点赞
  • 31
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
以下是一个简单的多分类类别平衡损失函数PyTorch代码示例: ```python import torch import torch.nn.functional as F def class_balanced_loss(y_true, y_pred, beta=0.99): # 计算每个类别的样本数量 class_counts = torch.sum(y_true, dim=0) # 计算每个类别的权重 class_weights = (1 - beta) / (1 - torch.pow(beta, class_counts)) # 计算加权的交叉熵损失 weighted_losses = F.binary_cross_entropy_with_logits(input=y_pred, target=y_true, pos_weight=class_weights) loss = torch.mean(weighted_losses) return loss ``` 这个函数的输入参数`y_true`是一个`N x K`的张量,其中`N`是样本数量,`K`是类别数量。每一行表示一个样本的真实标签,用`0`和`1`表示是否属于某个类别。例如,如果有3个样本,4个类别,那么`y_true`可能长这样: ``` [[0, 1, 0, 1], [1, 0, 1, 0], [0, 1, 1, 0]] ``` 这个函数的输入参数`y_pred`是一个`N x K`的张量,其中每个元素表示模型对该样本属于该类别的预测得分。例如,如果有3个样本,4个类别,那么`y_pred`可能长这样: ``` [[0.2, 0.8, 0.3, 0.7], [0.9, 0.1, 0.8, 0.2], [0.1, 0.9, 0.6, 0.4]] ``` 这个函数的输入参数`beta`是一个超参数,用于控制类别权重的平滑程度。一般来说,可以将其设置为一个接近于1的值,例如0.99。 使用这个函数的方法非常简单,只需要在训练模型时将损失函数设置为这个类别平衡损失函数即可。例如,使用PyTorch训练一个多分类模型时,可以这样设置损失函数: ```python loss_fn = class_balanced_loss optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) model.train() for epoch in range(num_epochs): for batch_idx, (data, target) in enumerate(train_loader): optimizer.zero_grad() output = model(data) loss = loss_fn(target, output) loss.backward() optimizer.step() ``` 需要注意的是,这个类别平衡损失函数并不是适用于所有情况的通用损失函数,其适用性和效果也与数据集的特点有关。在使用时需要根据实际情况进行调整和优化。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值