CatBoost中目标变量统计

CatBoost中的目标变量统计(Target Statistics)是其处理分类特征(Categorical Features)的核心技术之一。目标变量统计是一种特殊的编码方法,通过利用目标值信息生成数值特征,从而替代传统的独热编码或其他处理方法。这种方法对于具有高基数分类特征(如ID或字符串)特别高效。


目标变量统计的原理

目标变量统计编码的关键思想是用分类特征的历史统计信息来替代原始的类别值。例如,对于分类特征 A A A,其编码可以表示为该特征下目标变量 y y y 的均值、加权均值或其他统计量。

在CatBoost中,目标变量统计的计算方式遵循以下原则:

  1. 避免数据泄漏

    • 目标变量的统计值只能使用当前样本之前的数据计算,确保训练过程中不会泄漏测试数据的目标值。
    • 这通过Ordered Target Statistics来实现。
  2. 动态计算方式

    • 每一行的目标统计值根据之前数据的目标值动态计算,而不是直接使用整个数据集的统计值。
  3. 平滑处理

    • 为避免分类特征类别样本量较小时统计值不稳定,CatBoost对统计结果进行了平滑处理。
    • 一种常见的平滑方式是将类别均值与全局目标均值按权重组合。

计算目标变量统计的过程

1. 公式表达

目标变量统计编码通常采用以下形式计算:

S ( A i ) = ∑ j = 1 i − 1 y j + α ⋅ μ N i − 1 + α S(A_i) = \frac{\sum_{j=1}^{i-1} y_j + \alpha \cdot \mu}{N_{i-1} + \alpha} S(Ai)=Ni1+αj=1i1yj+αμ

其中:

  • A i A_i Ai 是样本 i i i 的分类特征值;
  • y j y_j yj 是样本 j j j 的目标值;
  • μ \mu μ 是目标变量的全局均值;
  • N i − 1 N_{i-1} Ni1 是类别 A i A_i Ai 在样本 1 1 1 i − 1 i-1 i1 中的出现次数;
  • α \alpha α 是平滑参数(控制全局均值对结果的影响)。
2. 分布处理

目标统计值通过逐行处理的方式计算,确保样本 i i i 的值不会用到样本 i i i 本身的目标值,从而避免信息泄漏。

3. 分组计算

对于训练数据集,CatBoost在内部根据数据顺序分组,先计算每组的目标统计,再将这些统计结果应用于模型训练。


Ordered Target Statistics的独特性

CatBoost的“Ordered Target Statistics”相较于其他目标编码方法的主要不同在于:

  1. 动态顺序计算
    • 按照训练数据的时间顺序逐步更新,确保每个样本的目标统计值基于其之前样本计算。
  2. 无信息泄漏
    • 避免了传统目标编码中使用目标变量的整体统计值而导致的未来信息泄漏问题。

具体案例

假设有一个数据集如下:

样本ID分类特征(City)目标变量(点击率)
1New York1
2Los Angeles0
3New York1
4Los Angeles1
5New York0

目标变量统计编码的过程如下:

  1. 第一行:
    • 对于 C i t y = New York City = \text{New York} City=New York,没有历史数据,目标统计值使用初始全局均值 μ \mu μ
  2. 第二行:
    • 对于 C i t y = Los Angeles City = \text{Los Angeles} City=Los Angeles,同样使用全局均值。
  3. 第三行:
    • 对于 C i t y = New York City = \text{New York} City=New York,基于前两行计算:
      S ( New York ) = 1 1 = 1 S(\text{New York}) = \frac{1}{1} = 1 S(New York)=11=1
  4. 以此类推。

目标变量统计的优点

  1. 对高基数分类特征有效

    • 比如用户ID、商品ID,这些特征类别非常多,传统方法(如独热编码)会导致高维稀疏矩阵,而目标统计可以生成紧凑的数值特征。
  2. 避免信息泄漏

    • Ordered Statistics的顺序计算确保每个样本的特征值与目标变量是独立的。
  3. 对模型性能提升显著

    • 目标变量统计利用了目标变量的潜在分布信息,可以提升模型预测精度。

代码实现示例

from catboost import CatBoostClassifier, Pool

# 示例数据
data = {
    'City': ['New York', 'Los Angeles', 'New York', 'Los Angeles', 'New York'],
    'Clicked': [1, 0, 1, 1, 0]
}

# 数据池
train_data = Pool(data=data['City'], label=data['Clicked'], cat_features=[0])

# 初始化模型
model = CatBoostClassifier(iterations=10, depth=2, learning_rate=0.1)

# 训练模型
model.fit(train_data)

# 查看目标变量统计
print(model.get_feature_importance(prettified=True))

此代码中,CatBoost会自动对City特征进行目标统计编码,无需用户显式指定。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

彬彬侠

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值