Focal loss in PyTorch

import torch

def multi_label_loss(y_pred, y_true):
    '''
    Pairwise multi-label loss from:
    Zhang, M. L., & Zhou, Z. H. (2006). Multilabel neural networks with applications to functional genomics and text categorization. IEEE Transactions on Knowledge and Data Engineering, 18(10), 1338-1351.
    '''
    # Note: 0.5 is the threshold of the activation function.
    # `config.categories` (number of labels) and `config.activ_th` (activation
    # threshold) come from the project's config object.
    y_true = torch.reshape(y_true, (-1, config.categories))
    y_pred = torch.reshape(y_pred, (-1, config.categories))

    ml_loss = []
    for i in range(y_true.shape[0]):
        y_true_i, y_pred_i = y_true[i], y_pred[i]

        ones_mask = (y_true_i == torch.tensor(1.))
        zero_mask = (y_true_i == torch.tensor(0.))
        ones, zeros = y_pred_i[ones_mask], y_pred_i[zero_mask]

        if ones.shape[0] == 0:
            # Ground truth has no positive label (e.g. ones.shape[0] == 0,
            # zeros.shape[0] == 8). With an activation threshold of 0.5, the
            # negative predictions only need to fall below 0.5, so use the
            # threshold itself as a dummy positive.
            ones = config.activ_th * torch.ones(1, requires_grad=True).cuda()
        elif zeros.shape[0] == 0:
            # Ground truth has no negative label (e.g. zeros.shape[0] == 0,
            # ones.shape[0] == 8). The positive predictions only need to
            # exceed the 0.5 threshold, so use it as a dummy negative.
            zeros = config.activ_th * torch.ones(1, requires_grad=True).cuda()

        # Enumerate every (positive, negative) prediction pair.
        p_repeat = ones.unsqueeze(1).expand(ones.size()[0], zeros.size()[0]).reshape((-1, 1))
        n_repeat = zeros.unsqueeze(0).expand(ones.size()[0], zeros.size()[0]).reshape((-1, 1))
        p_n_pairs = torch.cat((p_repeat, n_repeat), 1)

        # exp(-(positive - negative)), averaged over the |ones| * |zeros| pairs.
        ml_loss_i = torch.exp(-p_n_pairs[:, 0] + p_n_pairs[:, 1])
        ml_loss_i = torch.div(torch.sum(ml_loss_i), 1.0 * ones.shape[0] * zeros.shape[0])

        ml_loss.append(ml_loss_i.unsqueeze(0))

    ml_loss = torch.cat(ml_loss, 0)
    ml_loss = torch.mean(ml_loss)
    return ml_loss
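
A minimal usage sketch. The `config` class below is a hypothetical stand-in for the author's project config; since neither sample is all-positive or all-negative, the `.cuda()` branch never fires and this runs on CPU:

import torch

class config:          # hypothetical stand-in for the project's config object
    categories = 4     # number of labels
    activ_th = 0.5     # sigmoid activation threshold

# Two samples with four labels each; predictions come from a sigmoid.
y_true = torch.tensor([[1., 0., 1., 0.],
                       [0., 0., 0., 1.]])
y_pred = torch.sigmoid(torch.randn(2, 4, requires_grad=True))

loss = multi_label_loss(y_pred, y_true)
loss.backward()        # gradients flow back into y_pred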

def binary_focal_loss(y_pred, y_true, gamma=2., alpha=.25):
    """
    Extends the binary case to multi-label classification.
    Lin, Tsung-Yi, et al. "Focal loss for dense object detection." Proceedings of the IEEE International Conference on Computer Vision. 2017.

        Binary form of focal loss:
            FL(p_t) = -alpha * (1 - p_t)**gamma * log(p_t)
            where p = sigmoid(x), and p_t = p or 1 - p depending on whether the label is 1 or 0, respectively.
        :param y_true: a 1-D tensor of the same shape as `y_pred`
        :param y_pred: a 1-D tensor resulting from a sigmoid
        :return: output tensor
    """
    ones = torch.ones_like(y_pred, dtype=torch.float)
    zeros = torch.zeros_like(y_pred, dtype=torch.float)
    # pt_1 keeps predictions at positive positions (1 elsewhere, so log(pt_1) = 0);
    # pt_0 keeps predictions at negative positions (0 elsewhere, so that term vanishes).
    pt_1 = torch.where(y_true == 1, y_pred, ones)
    pt_0 = torch.where(y_true == 0, y_pred, zeros)

    # Clamp away from 0 and 1 to keep log() finite.
    epsilon = 1e-10
    pt_1 = torch.clamp(pt_1, epsilon, 1. - epsilon)
    pt_0 = torch.clamp(pt_0, epsilon, 1. - epsilon)

    bi_fl =  -torch.sum(alpha * torch.pow(1. - pt_1, gamma) * torch.log(pt_1)) \
             -torch.sum((1 - alpha) * torch.pow(pt_0, gamma) * torch.log(1. - pt_0))
    # Returning the sum is equivalent to alpha = n_batchsize; returning the mean
    # is equivalent to alpha = 1/n_class.
    return bi_fl
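
A quick usage sketch (shapes and values are illustrative): flatten labels and predictions to 1-D, apply a sigmoid to get probabilities, then call the loss:

import torch

y_true = torch.tensor([1., 0., 0., 1., 0.])
logits = torch.randn(5, requires_grad=True)
y_pred = torch.sigmoid(logits)   # the loss expects probabilities, not logits

loss = binary_focal_loss(y_pred, y_true, gamma=2., alpha=0.25)
loss.backward()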
	
def categorical_focal_loss(y_pred, y_true, gamma=2., weight=None):
    """ For single-label (multi-class) classification.
        :param y_true: a tensor of class indices [1, 2, 5, 4, 0, ...], shape = [n_batch]
        :param y_pred: a tensor resulting from a softmax, shape = [n_batch, n_class]
        :param weight: optional per-class weights, multiplied elementwise into
                       the loss before averaging
        :return: loss of a batch
    """
    # Convert y_true from class indices to a one-hot tensor of shape [n_batch, n_class].
    zeros = torch.zeros_like(y_pred, dtype=torch.float)
    index = (torch.arange(y_true.size(0)), y_true)
    y_true = zeros.index_put_(indices=index,
                              values=torch.ones_like(y_true, dtype=torch.float))

    # Clamp away from 0 and 1 to keep log() finite.
    epsilon = 1e-10
    y_pred = torch.clamp(y_pred, epsilon, 1. - epsilon)
    # Calculate cross entropy.
    cross_entropy = -y_true * torch.log(y_pred)
    # Calculate focal loss.
    loss = torch.pow(1 - y_pred, gamma) * cross_entropy
    if weight is not None:
        loss = weight * loss
    # Sum the losses in the mini-batch (or take the mean).
    return torch.mean(loss)
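
A usage sketch with made-up shapes: class indices go in as-is, and predictions are softmax probabilities:

import torch

y_true = torch.tensor([2, 0, 1])                 # class indices, shape [n_batch]
logits = torch.randn(3, 4, requires_grad=True)
y_pred = torch.softmax(logits, dim=-1)           # shape [n_batch, n_class]

loss = categorical_focal_loss(y_pred, y_true, gamma=2.)
loss.backward()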

Binary focal loss for sequence prediction

Addendum (2021.02)

def binary_focal_loss(y_pred, y_true, seq_len=None, gamma=2., alpha=.25):
    """
        Binary form of focal loss:
            FL(p_t) = -alpha * (1 - p_t)**gamma * log(p_t)
            where p = sigmoid(x), and p_t = p or 1 - p depending on whether the label is 1 or 0, respectively.
        :param y_true: a tensor of the same shape as `y_pred`
        :param y_pred: a tensor resulting from a sigmoid
        :param seq_len: optional true lengths of the padded sequences, shape = [batch_size]
        :return: output tensor
    """
    if seq_len is not None:
        # The model predicts labels over padded sequences, so keep only the
        # non-padded positions.
        # y_pred, y_true: shape = [batch_size, n_cxt (padded), n_class]
        # seq_len: shape = [batch_size]
        n_sents_unpad = torch.sum(seq_len)

        # Assumes a module-level `device` (e.g. torch.device('cuda')).
        y_pred_unpad = torch.zeros((0, y_true.size()[-1]), dtype=torch.float).to(device)
        y_true_unpad = torch.zeros((0, y_true.size()[-1]), dtype=torch.float).to(device)
        for i in range(y_true.size(0)):
            cur_n_cxt = seq_len[i]
            y_pred_unpad = torch.cat((y_pred_unpad, y_pred[i][:cur_n_cxt]), dim=0)
            y_true_unpad = torch.cat((y_true_unpad, y_true[i][:cur_n_cxt]), dim=0)
        assert n_sents_unpad == y_pred_unpad.size(0), 'Mismatched count of collected unpadded sentences.'
        y_pred, y_true = y_pred_unpad, y_true_unpad

    ones  = torch.ones_like(y_pred, dtype=torch.float)
    zeros = torch.zeros_like(y_pred, dtype=torch.float)
    pt_1  = torch.where(y_true == 1, y_pred, ones)
    pt_0  = torch.where(y_true == 0, y_pred, zeros)

    # Clamp away from 0 and 1 to keep log() finite.
    epsilon = 1e-10
    pt_1 = torch.clamp(pt_1, epsilon, 1. - epsilon)
    pt_0 = torch.clamp(pt_0, epsilon, 1. - epsilon)

    bi_fl =  -torch.sum(alpha * torch.pow(1. - pt_1, gamma) * torch.log(pt_1)) \
             -torch.sum((1 - alpha) * torch.pow(pt_0, gamma) * torch.log(1. - pt_0))
    # Returning the sum is equivalent to alpha = n_batchsize; returning the mean
    # is equivalent to alpha = 1/n_class.
    return bi_fl
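
A sketch of the padded-sequence case with toy shapes; `device` here is a stand-in for the module-level device the function expects:

import torch

device = torch.device('cpu')   # stand-in for the global the function uses

batch_size, max_len, n_class = 2, 3, 4
y_pred = torch.sigmoid(torch.randn(batch_size, max_len, n_class, requires_grad=True))
y_true = torch.zeros(batch_size, max_len, n_class)
y_true[0, :2, 1] = 1.          # first sequence has 2 real sentences
y_true[1, :, 0] = 1.           # second sequence has 3 real sentences

seq_len = torch.tensor([2, 3]) # true (unpadded) lengths
loss = binary_focal_loss(y_pred, y_true, seq_len=seq_len)
loss.backward()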