# 原文: https://zhuanlan.zhihu.com/p/28527749
import torch
# Build a per-element exponent for the focal term:
#   0.4 where focal_weight > 0.5, 2.2 where focal_weight < 0.5;
#   elements exactly equal to 0.5 keep the initial value 1.0.
# torch.ones_like already allocates on focal_weight's device with its dtype,
# so no explicit .cuda() is used: forcing CUDA fails on CPU-only hosts and
# yields a device mismatch in torch.pow whenever focal_weight does not live
# on the current CUDA device.
gamma = torch.ones_like(focal_weight)
gamma[focal_weight > 0.5] = 0.4
gamma[focal_weight < 0.5] = 2.2
# Raise each weight to its own exponent and scale by alpha_factor
# (alpha_factor and the incoming focal_weight are defined outside this snippet).
focal_weight = alpha_factor * torch.pow(focal_weight, gamma)
# 目标检测不行:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
class FocalLoss(nn.Module):
def __init__(self, gamma=2, alpha=0.25):
super(FocalLoss, self).__