使用 mlr3proba 的生存网络

我收到了许多关于 R 中的生存神经网络(“生存网络”)的问题,包括“这可能吗?”到“如何在 R 中安装 Python?”以及“我如何调整这些模型?”。如果您是对生存网络感兴趣的 R 用户,那么这篇文章就是适合您的!这不是如何使用相关包的教程,而是回答这些问题的演示。

这是一个高级演示,我假设您知道: i) 什么是生存分析; ii) 什么是神经网络(以及常见的超参数); iii) 基本的机器学习 (ML) 方法,例如重采样和调整。如果需要,我很乐意在以后的文章中全面介绍这些主题。

在本文中,我们将介绍如何: i) 在 R 中安装 Python 模块; ii) 将生存模型(Sonabend 2020 )中实现的模型与mlr3proba (Sonabend 等人2021 )结合使用; iii) 使用mlr3tuning调整模型(Lang, Richter, et al. 2019并使用mlr3pipelines预处理数据(Binder et al. 2019; iv) 在mlr3proba中对模型进行基准测试和比较; v) 在mlr3benchmark中分析结果(Sonabend 和 Pfisterer 2020。其中许多软件包都属于 mlr3 系列,如果您想了解有关它们的更多信息,我建议您从 mlr3book (Becker et al. 2021 a )开始。

本演示中的代码是一个“玩具”示例,选择在我非常旧的笔记本电脑上快速运行代码,所有型号的性能预计都会很差。

让我们深度学习吧!

安装软件包

我们将使用多个软件包,请确保安装以下软件包:

options(repos=c(
  mlrorg = 'https://mlr-org.r-universe.dev',
  raphaels1 = 'https://raphaels1.r-universe.dev',
  CRAN = 'https://cloud.r-project.org'
))
install.packages(c("ggplot2", "mlr3benchmark", "mlr3pipelines", "mlr3proba", "mlr3tuning", 
                   "survivalmodels", "mlr3extralearners"))
install.packages(c("ggplot2", "mlr3benchmark", "mlr3pipelines", "mlr3proba", 
                   "mlr3tuning", "survivalmodels"))
remotes::install_github("mlr-org/mlr3extralearners")

我已经安装了以下版本:

          ggplot2     mlr3benchmark mlr3extralearners     mlr3pipelines 
          "3.3.3"           "0.1.2"      "0.3.0.9000"           "0.3.4" 
        mlr3proba        mlr3tuning        reticulate    survivalmodels 
          "0.3.2"      "0.8.0.9000"            "1.19"           "0.1.8" 

生存模型

The package {survivalmodels} currently contains the neural networks:

  • CoxTime⁷
  • DeepHit⁸
  • DeepSurv⁹
  • Logistic-Hazard¹⁰ ¹¹
  • PCHazard¹¹
  • DNNSurv¹²

生存模型包当前包含神经网络:

  • CoxTime (Kvamme、Borgan 和 Scheel 2019
  • DeepHit (Lee 等人,2018
  • DeepSurv (Katzman 等人,2018
  • 物流风险(Gensheimer 和 Narasimhan 2019 (Kvamme 和 Borgan 2019
  • PCHazard (Kvamme 和 Borgan 2019
  • DNNSurv (赵和冯2020

其中前五个使用reticulate (Ushey、Allaire 和 Tang 2020连接由 Håvard Kvamme 编写的出色的 Python pycox (Kvamme 2018包,这意味着您可以以 Python 的速度在 R 中使用神经网络。 DNNSurv 使用 R keras (Allaire 和 Chollet 2020包。

在本文中,我们将只关注前五个,因为它们在文献中得到了更好的确立,并且它们具有相同的界面,这简化了调整,如下所示。本文中没有提供网络的描述,但如果有要求,我很乐意在以后的帖子中详细描述这些网络。

在 R 中使用 Python

要在生存模型中使用 Python 模型,您需要在 R 中设置 Miniconda 环境并安装所需的 Python 模块。使用survivalmodels的功能可以安装所需的模块:

library(survivalmodels)

install_pycox(pip = TRUE, install_torch = TRUE)
install_keras(pip = TRUE, install_tensorflow = TRUE)

install_pycox用于reticulate::py_install安装 Python 包pycox和可选的pycox (Paszke et al. 2017 ) ( install_torch = TRUE)。install_keras将安装pycox和可选的pycox (Abadi et al. 2015 ) ( install_tensorflow = TRUE)。

播种

确保 Python 实现的模型的可重复结果比平常稍微棘手一些,因为必须在多个位置设置种子。生存模型通过一个名为 的函数简化了这一过程set_seed

set_seed(1234)

MLR3概率

为了在安装后运行这些模型,我们将使用不同的界面。Survivalmodels的功能有限,这对于基本模型拟合/预测来说是可以的,但是神经网络通常需要数据预处理和模型调整,因此我们将使用mlr3proba ,它是mlr3的一部分(Lang、Binder 等) al. 2019软件包系列,包括概率监督学习的功能,生存分析是其中的一部分。mlr3包使用 R6 (Chang 2018 )接口在 R 中进行面向对象的机器学习。mlr3 的完整教程可以在 mlr3book 中找到还有一章使用mlr3proba进行生存分析(Becker et al. 2021 b )

现在让我们开始实验吧!

生存数据

我们需要做的第一件事是获取一些生存数据集来训练我们的模型,在mlr3proba数据集中保存在包含有关特征和目标的信息的任务中。我们将使用mlr3probawhas、 和我们自己设置的一项任务(尽管在mlr3proba中也已经可用,但这只是示例)。

library(mlr3)
library(mlr3proba)

## get the `whas` task from mlr3proba
whas <- tsk("whas")

## create our own task from the rats dataset
rats_data <- survival::rats
## convert characters to factors
rats_data$sex <- factor(rats_data$sex, levels = c("f", "m"))
rats <- TaskSurv$new("rats", rats_data, time = "time", event = "status")

## combine in list
tasks <- list(whas, rats)

获取和调整学习者

现在是您来这里的部分!我们将在生存模型(除了 DNNSurv 之外的所有模型)中训练和调整 Pycox 神经网络。调整由mlr3tuning包处理。我们不会为模型指定自定义架构,而是使用默认值,如果您熟悉 PyTorch,那么您可以选择创建自己的架构(如果您愿意)将其传递给模型中的 custom_net 参数。

超参数配置

训练和调整神经网络是一门艺术,但在本文中,我们将保持简单。我们将使用以下配置调整神经网络:

  • 辍学分数调整为 [0, 1]
  • 权重衰减超过 [0, 0.5]
  • 学习率超过 [0, 1]
  • {1,…,32} 层中的节点数
  • {1,…,4} 上的隐藏层数量

为了进行设置,我们使用paradox (Lang, Bischl, et al. 2019 )包(也是mlr3的一部分)来创建超参数搜索空间。生存模型中的所有 Pycox 学习器都具有相同的参数接口,因此只需提供一个搜索空间。在生存模型中,节点数num_nodes被指定为任意长度的向量,该向量不可直接调节。因此,我们分别调整层中的节点数nodes和层数 ,k然后提供一个转换来将两者结合起来。

library(paradox)
search_space = ps(
  # p_dbl for numeric valued parameters
  dropout = p_dbl(lower = 0, upper = 1),
  weight_decay = p_dbl(lower = 0, upper = 0.5),
  learning_rate = p_dbl(lower = 0, upper = 1),
  
  # p_int for integer valued parameters
  nodes = p_int(lower = 1, upper = 32),
  k = p_int(lower = 1, upper = 4)
)

search_space$trafo = function(x, param_set) {
  x$num_nodes = rep(x$nodes, x$k)
  x$nodes = x$k = NULL
  return(x)
}

请注意,在我们的转换中,我们假设每层的节点数相同,这是一个相当常见的假设,但可以考虑更高级的转换。

现在,我们将学习器包装在 中AutoTuner,这使得学习器可以在基准实验中轻松调整。当我们调整多个相似的学习器时,我们可以创建一个函数,使创建 AutoTuner 变得更容易。为了进行调整,我们使用:2/3 分割保留、c 索引优化和 2 次迭代随机搜索。这些设置不应该在实践中使用,只是为了让演示运行得更快,在实践中我通常推荐 3 重嵌套交叉验证、rsmp("cv", folds = 3)60 次迭代随机搜索(Bergstra 和 Bengio 2012、。trm("evals", n_evals = 60)

library(mlr3tuning)

create_autotuner <- function(learner) {
  AutoTuner$new(
  learner = learner,
  search_space = search_space,
  resampling = rsmp("holdout"),
  measure = msr("surv.cindex"),
  terminator = trm("evals", n_evals = 2),
  tuner = tnr("random_search"))
}

现在让我们的学习者并应用我们的函数。对于所有学习者,我们将设置以下超参数:

  • 30%的嵌套训练数据将被保留作为early_stopping的验证数据,frac = 0.3, early_stopping = TRUE
  • 亚当优化器,optimizer = “adam"
  • 最多 10 个 epoch,epochs = 10

由于我们使用提前停止,epoch 的数量通常会大量增加(比如最少 100),但这里再次减少以运行得更快。所有其他超参数均使用模型默认值。

## learners are stored in mlr3extralearners
library(mlr3extralearners)

## load learners
learners <- lrns(paste0("surv.", c("coxtime", "deephit", "deepsurv", 
                                   "loghaz", "pchazard")),
                 frac = 0.3, early_stopping = TRUE, epochs = 10,
                 optimizer = "adam"
)
 
# apply our function
learners <- lapply(learners, create_autotuner)

预处理

所有神经网络都需要一些数据预处理。通过mlr3pipelines包,特别是encode和pipelineops,这可以变得简单,scale它们分别执行 one-hot 编码和特征标准化(通过更改参数可以使用其他方法)。我们将再次创建一个可以应用于所有学习者的函数。

library(mlr3pipelines)

create_pipeops <- function(learner) {
  po("encode") %>>% po("scale") %>>% po("learner", learner)
}

# apply our function
learners <- lapply(learners, create_pipeops)

基准

我们准备好了!对于我们的实验,我们将使用 3 倍交叉验证,但通常首选 5 倍交叉验证rsmp("cv", folds = 5)。为了进行比较,我们还将在实验中添加 Kaplan-Meier (Kaplan 和 Meier 1958和 Cox PH (Cox 1972学习器。我们将把我们的基准结果与 Harrell 的 C 指数(Harrell、Califf 和 Pryor 1982和综合 Graf 评分(Graf 等人1999(还有许多其他衡量标准)进行汇总。

## select holdout as the resampling strategy
resampling <- rsmp("cv", folds = 3)

## add KM and CPH
learners <- c(learners, lrns(c("surv.kaplan", "surv.coxph")))
design <- benchmark_grid(tasks, learners, resampling)
bm <- benchmark(design)

我们可以通过不同的衡量标准来汇总结果:

## Concordance index and Integrated Graf Score
msrs <- msrs(c("surv.cindex", "surv.graf"))
bm$aggregate(msrs)[, c(3, 4, 7, 8)]
    task_id                       learner_id surv.harrell_c  surv.graf
 1:    whas  encode.scale.surv.coxtime.tuned      0.5402544 0.23660491
 2:    whas  encode.scale.surv.deephit.tuned      0.4930036 0.39716194
 3:    whas encode.scale.surv.deepsurv.tuned      0.6951130 0.21126205
 4:    whas   encode.scale.surv.loghaz.tuned      0.5116887 0.28632792
 5:    whas encode.scale.surv.pchazard.tuned      0.5247859 0.29908496
 6:    whas                      surv.kaplan      0.5000000 0.23717567
 7:    whas                       surv.coxph      0.7391565 0.20120291
 8:    rats  encode.scale.surv.coxtime.tuned      0.6309753 0.05685849
 9:    rats  encode.scale.surv.deephit.tuned      0.3769910 0.36025715
10:    rats encode.scale.surv.deepsurv.tuned      0.5693399 0.05736899
11:    rats   encode.scale.surv.loghaz.tuned      0.6154914 0.15546280
12:    rats encode.scale.surv.pchazard.tuned      0.6745137 0.34789168
13:    rats                      surv.kaplan      0.5000000 0.05799158
14:    rats                       surv.coxph      0.7762968 0.05321489

在我们的玩具演示中,我们可以从这些结果中初步得出结论:Cox PH 表现最好,而 DeepHit 表现最差。

分析结果

由于我们已在多个独立数据集上运行模型,因此我们可以与mlr3benchmark更详细地比较我们的结果。下面带注释的代码只是展示了可能的内容,但未提供任何详细信息(如果您对此感兴趣,请在未来的教程中告诉我!)。

library(mlr3benchmark)

## create mlr3benchmark object
bma <- as.BenchmarkAggr(bm, 
                        measures = msrs(c("surv.cindex", "surv.graf")))

## run global Friedman test
bma$friedman_test()
                X2 df    p.value p.signif
harrell_c     10.5  6  0.1051144         
graf      11.78571  6 0.06692358        .

Friedman 检验结果表明,模型之间在任一测量方面均没有显着差异(假设 p ≤ 0.05 是显着的)。现在,假设如果 p ≤ 0.1,模型就会显着不同(我一般不建议这样做),这样我们就可以看一下关键差异图(Demšar 2006 )来比较这些模型。

## load ggplot2 for autoplots
library(ggplot2)

## critical difference diagrams for IGS
autoplot(bma, meas = "graf", type = "cd", ratio = 1/3, p.value = 0.1)

结果表明,没有模型优于 Kaplan-Meier 基线,并且我们的分析已完成(对于这个玩具设置来说并不奇怪!)。

概括

在本演示中,我们使用了用 Python 实现并通过生存模型进行交互的神经网络。我们使用mlr3proba接口来加载这些模型并获得一些生存任务。我们使用mlr3tuning来设置超参数配置和调整控制,并使用mlr3pipelines进行数据预处理。最后,我们使用mlr3benchmark分析多个数据集的结果。我希望本文能够演示 mlr3 界面如何使从生存模型中选择、调整和比较模型变得更加简单。

  • 10
    点赞
  • 19
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

皮肤小白生

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值