DESCN:用于个体治疗效果估计的深度全空间交叉网络

英文题目:DESCN: Deep Entire Space Cross Networks for Individual Treatment Effect Estimation

翻译:用于个体治疗效果估计的深度全空间交叉网络

单位:阿里

作者介绍: About | Kailiang Zhong

 Uplift Modeling在工业界有哪些成功的应用? - 知乎

论文链接:

代码:https://github.com/kailiang-zhong/DESCN

摘要:https://github.com/kailiang-zhong/DESCN

因果推理在各个领域都有广泛的应用,如电子商务和精准医疗,其性能在很大程度上依赖于个体治疗效果(ITE)的准确估计。传统上,ITE 是通过在它们各自的样本空间中分别建模处理和控制响应函数来预测的。然而,这种方法在实践中通常会遇到两个问题,即由于治疗偏差而处理组和对照组之间的分布不同,并且它们的种群大小的样本不平衡显着。本文提出了一种深度全空间交叉网络(DESCN)从端到端角度对治疗效果进行建模。DESCN以多任务学习的方式通过交叉网络捕获治疗倾向、反应和隐藏治疗效果的综合信息。我们的方法联合学习整个样本空间中的治疗和响应函数,以避免治疗偏差,并采用中间伪治疗效果预测网络来缓解样本不平衡。在合成数据集和来自电子商务凭证分发业务的大规模生产数据集上进行了广泛的实验。结果表明,DESCN可以成功地提高ITE估计精度,提高提升排序性能。发布了生产数据集和源代码的样本,以促进社区未来的研究,据我们所知,这是第一个用于因果推理的大规模公共有偏见的处理数据集。

KEYWORDS

Causal Inference; Uplift Modeling; Individual Treatment Effect Estimation; Deep Learning; Entire Space Modeling

解决的问题:二元干预

目录

1介绍

2相关工作

3方法

3.1问题定义

3.2整个空间网络(ESN)

3.3 X-network

3.4 DESCN模型架构

 4实验

4.1 数据集

4.2比较方法

4.3评估指标 

4.4超参数设置 

5结果与讨论

6 结论和未来的工作


1介绍

个体层面的因果推理是一种预测分析技术,用于估计单一或多种治疗的个体治疗效果(ITE)。该技术具有广泛的应用,例如识别对患者最有效的药物”,并针对个性化保险产品优化交叉销售 [5]。在电子商务领域也很受欢迎,因为利润驱动的业务,如代金券分布和目标广告。在本文中,我们专注于 ITE 估计任务,其中仅存在单个处理(即处理或未处理3)。该技术具有广泛的应用,例如识别患者最有效的药物 [8],并针对个性化保险产品优化交叉销售 [5]。它在电子商务领域也很受欢迎,作为利润驱动的业务,如凭证分发和目标广告。

In this paper, we focus on the ITE estimation task where only a single treatment (i.e., treated or non-treated3) exists.

治疗效果建模(又称提升建模)的一个基本挑战是反事实结果的存在,也就是说,我们不能在单个完全相同的上下文中观察治疗和控制响应。因此,不能直接测量个体治疗效果。这一挑战的一个常见解决方案是分别在两个响应函数上构建模型,以得出反事实预测。

此外,在实践中遇到了两个额外的重要问题。首先,由于混杂因素的存在,治疗组组和对照组的潜在分布可能会有所不同。我们将此称为治疗偏差问题。当处理组和对照组不是随机选择的,而是由某些特定因素选择时,就会出现混淆因素。以电子商务凭证分发场景为例,某些折扣报价(也称为凭证)只能提供给非活动用户,旨在提高客户保留率。活跃用户,然而,出于节省成本的目的,没有给出凭证。因此,对照组的用户往往比治疗组的用户更活跃。它使模型难以通过损失函数学习这两组之间的无偏表示以进行ITE估计,如反事实回归(CFR)。

其次,由于特定的治疗策略,治疗组组和对照组之间的样本量在实践中可能会有很大差异,从而导致样本不平衡问题。在电子商务领域,自由成员(即治疗)可以提供给大多数用户群体以促进新产品。这使得治疗组的样本量远大于对照组。相比之下,在代金券分发场景中,凭证通常只分发给促销敏感的用户(即只有在给予促销时才会购买的用户),导致治疗组规模相对较小。这种不平衡的数据可能会使模型很难学习到对治疗效果的准确估计,并需要额外的校准工作。

为了解决上述问题,我们提出了一个多任务交叉网络来捕获治疗倾向、真实响应和名为深度整个空间交叉网络 (DESCN) 的集成框架中的伪治疗效果的因素。我们引入了一个整个空间网络 (ESN) 来共同学习整个样本空间中的处理和响应函数。ESN 不是单独建模单个样本空间中的治疗响应 (TR) 和控制响应 (CR),而是应用倾向网络来学习治疗倾向,然后与 TR 和 CR 连接以导出整个空间治疗反应 (ESTR) 和整个空间控制响应 (ESCR)。通过这种方式,该模型可以直接在 STR 和 WSCR 上进行训练,其中两个响应函数可以利用整个样本来解决治疗偏差问题。此外,它在共享结构之上引入了一个额外的新颖交叉网络(X 网络),以学习隐藏的伪处理效果,它充当中间变量来与两个真实响应函数(TR 和 CR)连接以模拟反事实响应。通过这种方式,我们可以学习一个包含所有响应和处理效果信息的集成表示来缓解样本不平衡问题。

本文的主要贡献如下:

•提出了一种端到端多任务交叉网络DESCN,该网络以综合的方式捕获治疗倾向、真实反应和伪治疗效果之间的关系,同时缓解治疗偏差和样本不平衡问题。在合成数据集和大规模生产数据集上进行的大量实验表明,DESCN 在 ITE 估计精度和提升排名性能方面均优于基线模型 +4% 到 +7%。

•用于对整个样本空间中处理和响应函数的联合分布进行建模的全空间网络 (ESN)。通过这种方式,我们可以以集成的方式推导出处理组和对照组的反事实信息,以利用整个样本并解决治疗偏差问题。请注意,ESN不仅限于DESCN模型,还可以应用于其他模型。现有的基于个体响应函数估计的提升模型。

•从电子商务平台Lazada收集了关于代金券分布的大规模生产数据集。我们特别设计了实验来生成训练集中的强治疗偏差,但在测试集中使用随机处理来更好地评估模型的性能。据我们所知,这是第一个同时具有训练和测试集中有偏和随机处理的工业生产数据集,我们希望这可以帮助促进未来因果推理的研究。

2相关工作

元学习 [7, 25, 26] 是估计 ITE 的流行框架,它使用任何机器学习估计器作为基础学习器,并通过比较基础模型返回的估计响应来识别治疗效果。S-learner 和 T-learner 是两种常用的元学习范式。S-learner 在整个空间上进行训练,结合处理样本和对照样本,将治疗指标作为加性输入特征。T-learner 使用在治疗和控制样本空间上单独构建的两个模型。在 S-learner 和 T-learner 中,当两组的人口规模不平衡时,相应基础模型之间的性能差异可能不一致并损害 ITE 估计性能。为了克服这个问题,提出了 X-learner [10],它采用从对照组中学习到的信息来更好地估计治疗组,反之亦然。X-learner 的第一步是分别学习两个类似 T-learner 的模型。然后,它计算观察结果与治疗组组和对照组的估计结果之间的差异作为估算治疗效果,进一步用于训练一对ITE估计器。最后,整体 ITE 计算为两个估计器的加权平均值。与其他元学习方法一样,基础模型的性能在 X-learner 中仍然有很强的级联效应

3方法

3.1问题定义

3.2整个空间网络(ESN)

3.3 X-network

3.4 DESCN模型架构

DESCN 的整体模型架构如图 1c 所示,它是整个空间网络 (ESN) 在 X 网络上的直接应用。这样可以缓解3.1中提到的治疗偏差和样本不平衡。此外,DESCN 中的共享网络可以同时学习倾向得分和控制响应,这也可以帮助学习捕获所有这些信息的综合表示。

 

 

 4实验

4.1 数据集

为了比较ITE估计精度和提升排序视角下的模型性能,我们使用了一个合成数据集,其中ground truth处理效果是可用的,另一个真实世界的数据集存在治疗偏差和样本不平衡问题。这两个数据集的统计数据总结在表 1 中。

第一个数据集是基于2019年美国因果推理会议(ACIC)数据挑战6的公共代码生成的合成数据集。我们使用提供的数据生成器处理器(DPG)从高维癫痫发作识别数据集(癫痫)[1]收集的协变量来生成40k个观察样本7。该合成数据集的优点是,地面真实治疗效果从DPG中完全知道,因此癫痫数据集是评估模型在ITE估计精度上的模型性能的理想选择。

为了进一步评估模型在提升排名性能方面的性能,我们使用了来自拉兹达的真实凭证分布业务场景的另一个大规模生产数据集,这是阿里巴巴集团领先的东南亚 (SEA) 电子商务平台。在实际生产环境中,由于操作靶向策略,治疗分配是选择性的,我们将这些数据收集为我们的训练集,其中包括强大的治疗偏差。我们还的用户规模略小,不受目标策略的影响,治疗分配遵循随机对照试验 (RCT),我们将它们用作测试数据集。

通过上述设置,训练集包含有偏差的处理数据,而测试集仅包含无偏的处理数据。在训练集中,由于处理偏差问题,处理和控制分布自然发散,两组之间的样本量在现实世界的场景中自然不同,如表1所示。

4.2比较方法

我们将 DESCN 与 ITE 估计中常用的以下模型进行比较,以评估其性能。所有模型都使用两个数据集中的所有输入密集协变量。

X-learner (NN)[10]:我们将 X-learner 作为具有代表性的元学习器。选择神经网络(NN)作为基础学习器,具有更好的非线性学习优势,并与我们提出的基于 NN 的模型进行公平比较。包括两个响应拟合模型、两个估算治疗效果拟合子模型和一个倾向模型。所有模型都在没有共享网络参数的情况下单独训练。估计的倾向分数用于最终的ITE输出以获得更好的性能。

•BART[2]:贝叶斯加性回归树(BART)是一个树和模型,用于评估非深度学习模型的性能。

•因果森林[22]:因果森林是一种基于非参数随机森林的树模型,直接估计治疗效果,这是基于树的抬升模型的另一个代表。

•TARNet[17]: TARNet是一种常用的深度学习抬升模型。与 X-learner (NN) 相比,它省略了拟合子模型的附加估算治疗效果,但引入了处理和控制响应网络的共享层。共享网络参数有助于缓解样本的不平衡。

•CFR[17]: CFR对TARNet应用额外的损失,迫使学习到的处理和控制协变量分布更接近。这可以帮助学习这两组之间的平衡表示并解决不平衡问题。

•x网络:x网络在第3.3节中介绍。它可以看作是我们最终模型DESCN的一个变体,其中ESN部分被删除。

4.3评估指标 

公式来源:Perfect Match: A Simple Method for Learning Representations For Counterfactual Inference With Neural Networks

4.4超参数设置 

我们在不同模型的公共组件中保持相同的超参数。具体来说,在癫痫数据集中,跨不同深度模型的所有神经网络由 128 个隐藏单元和 3 个全连接层组成。L2 正则化用于缓解系数为 0.01 的过度拟合,学习率设置为 0.001,没有衰减。所有模型都以 500 的批大小和 epoch 为 15 进行训练,唯一的区别是与每个特定网络设置相关的超参数,例如不同损失函数的权重。同样,对于生产数据集,所有神经网络也由 3 个全连接层组成。共享网络包含 128 个隐藏单元,其他子模型包含 64 个隐藏单元。L2 正则化值设置为 0.001,学习率设置为 0.001,衰减率为 0.95。所有模型都以 5000 和 5 个 epoch 的批大小进行训练。

5结果与讨论

在本节中,我们将介绍 DESCN 的整体性能以及要比较的其他方法。还提出了并研究了以下问题:

•Q1:提出的全空间网络(ESN)是否有助于提高整体性能,缓解治疗偏差问题?

•Q2:提出的交叉网络(X-network)和伪治疗效果(PTE)网络是否有助于学习平衡处理和控制响应函数。DESCN和其他基线模型的实验结果如表2所示。

6 结论和未来的工作

在本文中,我们提出了一种用于个体治疗效果 (ITE) 估计、深度全空间交叉网络 (DESCN) 的集成多任务模型。它引入了一个全空间网络(ESN),共同学习整个样本空间中处理和响应函数的分布,以解决治疗偏差问题。此外,提出了一种新颖的交叉网络(X-network),它能够直接对隐藏治疗效果进行建模,并且还具有减轻分布样本不平衡的优势。最后,DESCN以多任务学习的方式将ESN和X网络结合起来,通过交叉网络集成了治疗倾向、反应和隐藏治疗效果的信息。在改编自 ACIC 数据挑战和大规模凭证分布生产数据集的合成数据集中进行了广泛的实验。结果表明,DESCN 在合成数据集和生产数据集上分别提高了 +4% 到 +7%。这表明 DESCN 在 ITE 估计精度和提升排名方面是有效的。

探索有几个潜在的未来工作。在工业应用中,可以同时给出多种处理。通过将 X 网络扩展到多个响应函数,可以研究 DESCN 扩展到多处理设置。另一个方向是扩展 DESCN 以适应连续治疗结果,这也具有广泛的应用。

参考

KDD‘22 阿里|DESCN: Deep Entire Space Cross Networks | 多任务、端到端、IPW的共舞 - 知乎

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
网上购物商城数据库设计 一、概述 网上购物店的数据模型, 它主要模式有产品: product , 帐户: Account, 定单:Order。和产品相关的表有 category ,product,item, inventory, supplier;和用户相关表有的 account ,signon ,profile;和定单相关 的表有 orders,orderstatus,lineitem ,他们之间的整体关系如下. ERD 图 FK:Foreign Key 二、帐户模型 帐户模型,记录者用户的登录名称,密码。以及个人信息如地址,性 名,电话等,还有它在系统中的 profile 信息。表有 Account 主键是 userID,它记录用户的基本信息,如 email,name 等。Signon 表记录者 userID 和 password,Profile 表记录者用户的登录系统的系统设置。可 以根据用户的类型,显示不同的登录信息。 (1)account 表 create table account ( userid varchar(80) not null, email varchar(80) not null, name varchar(80) not null, status char(2) null, addr1 varchar(80) not null, addr2 varchar(40) null, city varchar(80) not null, state varchar(80) not null, zip varchar(20) not null, country varchar(20) not null, phone varchar(80) not null, constraint pk_account primary key (userid) ) 说明:primary key 是 userID,它记录帐户的基本信息。 (2)Signon 表 create table signon ( username varchar(25) not null, password varchar(25) not null, constraint pk_signon primary key (username) ) 说明:记录登录名和密码。 (3)Profile 表 create table profile ( userid varchar(80) not null, langpref varchar(80) not null, favcategory varchar(30), mylistopt int, banneropt int, constraint pk_profile primary key (userid) ) 说明:用户的登录信息,方便个性化定制。 (4)Bannerdata 表 create table bannerdata ( favcategory varchar(80) not null, bannername varchar(255) null, constraint pk_bannerdata primary key (favcategory) ) 说明:记录不同的登录信息。 三、产品模型 产品的模型主要有分类, 它是产品的大类。 表 category 就是记录分类 名称,描述信息。Product 记录每个产品的基本信息,包括产品名称,和产品的描述。它是一对 多的关系。Supplier 表 记录产品的提供者信息,包括提供者的名称,地址,状态等。Item 记 录产品的提供者,产 品 ID,价格,状态。Inventory 表记录产品的数量。关系如下: (1) category 表 create table category ( catid char(10) not null, name varchar(80) null, descn varchar(255) null, constraint pk_category primary key (catid) ) (2)product 表 create table product ( productid char(10) not null, category char(10) not null, name varchar(80) null, descn varchar(255) null, constraint pk_product primary key (productid), constraint fk_product_1 foreign key (category) references category (catid) ) (3) item 表 c

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值