Focal Loss原理及实现

1 什么是Focal Loss?

最近工作中,Leader让了解一下Focal Loss,尝试解决信贷场景下样本天然不平衡的问题,于是就开始吭哧吭哧的查资料。

起源

  • Focal Loss的提出是在目标检测的领域之中
  • 目标检测的框架一般分为两种:基于候选区域的two-stage的检测框架(fast r-cnn)基于回归的one-stage的检测框架(yolo)
  • two-stage效果好,但速度慢;one-stage效果一般,但速度快
  • 作者就去探寻为啥one-stage效果一般,最终发现的原因是 正负样本不均衡导致

于是作者就提出了一个牛逼哄哄的办法,使用Focal Loss这种损失函数,来尝试解决这一问题!

2 什么场景下用Focal Loss?

针对样本不平衡的情况下,使用Focal Loss作为损失函数,加强对于hard example的训练!从而一定程度上解决样本不平衡问题!

3 Focal Loss的原理是什么?为什么能解决样本不平衡问题?

Focal Loss核心思想是:整体缩放Loss,易分类样本缩放的比难分类样本更多,从而损失函数中就凸显了难分类样本的权重,使得模型在训练时更专注于难分类的样本。

具体来看下Focal Loss的原理,我们对比的是常见的交叉熵损失函数-binary loss。

3.1 交叉熵损失函数binary loss

交叉熵损失函数见下:
在这里插入图片描述
举例:

  • y’=0.9,易分类样本,属于y=1的样本,那么损失L1=-log0.9,非常接近0的一个正数
  • y’=0.6,难分类样本,无论是y=1还是y=0,损失L2都会相对比较大
  • 最终总的损失函数是将每一个样本对应的损失函数相加,所有样本权重一致。

3.2 Focal Loss的改进

那么Focal Loss改进的直观想法是如何的呢?上面binary loss最终每个样本的权重都是一致的,我们能不能设计一个系数,让易分类样本权重降低,难分类样本权重提高呢?完全可以!

Focal Loss的定义见下图:
在这里插入图片描述
我们来尝试做一个分解,现在这么看有点绕。

  • 首先,正负样本不平衡(y=1样本少),那么直观的想法就是对于两大类样本直接加一个权重,也就是上图中的α
  • 但是α只能解决整体正负样本比的问题,无法解决更核心的问题:希望易分类样本的权重低一些,难分类样本的权重高一些,更加在损失函数中凸显出来!
  • 因此,引入(1-y’)和γ参数。

暂定先取α=0.25,γ=2。这两个参数需要根据具体的数据来进行参数调整。

  • y’=0.9,易分类样本,属于y=1的样本,那么损失L1’=-α(0.1)γ*log0.9,相比原来的的L1,显著降低了很多
  • y’=0.6,难分类样本,无论是y=1还是y=0,损失L2都会相对比较大。L2’=-(1-α)(0.6)γ*log(0.4),相比原来的L2虽然也降低了,但是没有上述L1’降低的那么多!
  • 虽然最终总的损失函数是将每一个样本对应的损失函数相加,但此时所有样本权重并不是一致的了,易分类样本的损失函数显著降低了很多,相当于权重变小难分类样本的损失函数虽然也缩放了,但是缩放降低的比例比易分类样本要小,相当于权重变大了!从而实现了损失函数中更加侧重于难分类样本(hard example)!

4 Focal Loss的实现

4.1 导入库

from sklearn.model_selection import cross_val_score
from sklearn.ensemble import RandomForestClassifier
from sklearn.ensemble import ExtraTreesClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split

import pandas as pd
import numpy as np
from LGB_Model_FL import * # 专属的脚本文件LGB_Model_FL.py
import lightgbm
from Model_Analysis import * # 专属的脚本文件Model_Analysis.py
from IV_Cal import *  # 专属的脚本文件 IV_Cal.py
import os
from datetime import timedelta
from datetime import datetime
from dateutil.relativedelta import relativedelta
from scipy.misc import derivative
import warnings
warnings.filterwarnings('ignore')

# focal loss 损失函数
def focal_loss_lgb_sk(y_true, y_pred, alpha, gamma):
    """
    Focal Loss for lightgbm

    Parameters:
    -----------
    y_pred: numpy.ndarray
        array with the predictions
    dtrain: lightgbm.Dataset
    alpha, gamma: float
        See original paper https://arxiv.org/pdf/1708.02002.pdf
    """
    a,g = alpha, gamma
    def fl(x,t):
        p = 1/(1+np.exp(-x))
        return -( a*t + (1-a)*(1-t) ) * (( 1 - ( t*p + (1-t)*(1-p)) )**g) * ( t*np.log(p)+(1-t)*np.log(1-p) )
    partial_fl = lambda x: fl(x, y_true)
    grad = derivative(partial_fl, y_pred, n=1, dx=1e-6)
    hess = derivative(partial_fl, y_pred, n=2, dx=1e-6)
    return grad, hess

# focal loss 对应的评估函数metric
def focal_loss_lgb_eval_error_sk(y_true, y_pred, alpha, gamma):
    """
    Adapation of the Focal Loss for lightgbm to be used as evaluation loss

    Parameters:
    -----------
    y_pred: numpy.ndarray
        array with the predictions
    dtrain: lightgbm.Dataset
    alpha, gamma: float
        See original paper https://arxiv.org/pdf/1708.02002.pdf
    """
    a,g = alpha, gamma
    p = 1/(1+np.exp(-y_pred))
    loss = -( a*y_true + (1-a)*(1-y_true) ) * (( 1 - ( y_true*p + (1-y_true)*(1-p)) )**g) * ( y_true*np.log(p)+(1-y_true)*np.log(1-p) )
    return 'focal_loss', np.mean(loss), False

def sigmoid(x):
    return 1/(1+np.exp(-x))

4.2 切分数据

df = pd.read_csv('telecom_churn.csv')
df['churn'] = df['churn'].map(str)
churn_dic = {'True':1, 'False':0}
df['churn'] = df['churn'].map(churn_dic)
print(df.shape)
df.head()
(3333, 21)
stateaccount lengtharea codephone numberinternational planvoice mail plannumber vmail messagestotal day minutestotal day callstotal day charge...total eve callstotal eve chargetotal night minutestotal night callstotal night chargetotal intl minutestotal intl callstotal intl chargecustomer service callschurn
0KS128415382-4657noyes25265.111045.07...9916.78244.79111.0110.032.7010
1OH107415371-7191noyes26161.612327.47...10316.62254.410311.4513.733.7010
2NJ137415358-1921nono0243.411441.38...11010.30162.61047.3212.253.2900
3OH84408375-9999yesno0299.47150.90...885.26196.9898.866.671.7820
4OK75415330-6626yesno0166.711328.34...12212.61186.91218.4110.132.7330

5 rows × 21 columns

4.3 分训练集和测试集

# 切分数据
X = df.iloc[:,8:19]
# X = df[['total day calls', 'total night charge', 'number vmail messages', 'total intl charge']]

y = df['churn'].values

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.3,
                                                    random_state = 23)
print(X_train.shape, X_test.shape, y_train.shape, y_test.shape)
(2333, 11) (1000, 11) (2333,) (1000,)

4.4 Focal Loss+Lightgbm

# LGB+Focal Loss 其中alpha:为不能让容易分类类别的损失函数太小;gamma:更加关注困难样本 即关注y=1的样本
focal_loss = lambda x,y: focal_loss_lgb_sk(x, y, alpha = 0.25, gamma = 2)

lgb_param = {
    'learning_rate' : 0.01,
    'max_depth':3,
    'n_estimators':300,
    'num_leaves' : 8,
    'subsample':0.7,
    'subsample_freq':3,
    'colsample_bytree':0.7,
    'scale_pos_weight':1,
    'subsample_for_bin':200000,
    'min_split_gain':0,
    'min_child_weight':1e-3,
    'min_child_samples':20,
    'reg_alpha':0,
    'reg_lambda':10,
    'n_jobs':[-1],

    'silent':True,
    'class_weight':None,
    'random_state':None,
    'boosting_type':'gbdt',
    'objective' : focal_loss
    # 'objective' : 'binary_loss',

}

model = LGB_Train_Test(lgb_param, X_train, y_train, X_test, y_test)
Model Accuracy on Train set: 86.4981%
Model Accuracy on Test set: 83.8000%
The KS value of Train set is:
0.4307689687420647

在这里插入图片描述
在这里插入图片描述
The KS value of Test set is:
0.410751642652995
在这里插入图片描述
在这里插入图片描述

5 写在最后

Focal Loss上述只是在一个demo数据集上跑通了,在实际的信贷数据中,Focal loss效果相比binary loss是有所提升的!涉及到公司的数据隐私,就不放图了。

完结撒花!

6 参考资料

  • https://zhuanlan.zhihu.com/p/49981234
  • https://zhuanlan.zhihu.com/p/32423092
  • https://blog.csdn.net/u014380165/article/details/77019084
  • https://blog.csdn.net/qq_34564947/article/details/77200104
  • focal loss论文:https://web.kamihq.com/web/viewer.html?source=extension_pdfhandler&extension_handler=webrequest_1_autoload_true_user_8325679&file=https%3A%2F%2Farxiv.org%2Fpdf%2F1708.02002.pdf
Focal Loss是一种用于解决样本不平衡问题的损失函数。它通过调整样本的权重,使得模型更加关注难以分类的样本。Focal Loss实现可以基于二分类交叉熵。\[1\] 在PyTorch中,可以通过定义一个继承自nn.Module的类来实现Focal Loss。这个类需要定义alpha(平衡因子)、gamma(调整因子)、logits(是否使用logits作为输入)和reduce(是否对损失进行求和)等参数。在forward函数中,根据输入和目标计算二分类交叉熵损失,并根据Focal Loss的公式计算最终的损失。\[1\] 在Keras中,可以通过定义一个自定义的损失函数实现Focal Loss。这个函数需要定义alpha和gamma等参数,并根据Focal Loss的公式计算损失。然后,将这个损失函数作为参数传递给模型的compile函数。\[3\] 总结来说,Focal Loss实现可以基于二分类交叉熵,通过调整样本的权重来解决样本不平衡问题。在PyTorch中,可以定义一个继承自nn.Module的类来实现Focal Loss;在Keras中,可以定义一个自定义的损失函数实现Focal Loss。\[1\]\[3\] #### 引用[.reference_title] - *1* [关于Focal loss损失函数的代码实现](https://blog.csdn.net/Lian_Ge_Blog/article/details/126247720)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^koosearch_v1,239^v3^insert_chatgpt"}} ] [.reference_item] - *2* [Focal Loss原理实现](https://blog.csdn.net/qq_27782503/article/details/109161703)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^koosearch_v1,239^v3^insert_chatgpt"}} ] [.reference_item] - *3* [Focal Loss --- 从直觉到实现](https://blog.csdn.net/Kaiyuan_sjtu/article/details/119194590)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^koosearch_v1,239^v3^insert_chatgpt"}} ] [.reference_item] [ .reference_list ]
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值