生存回归(加速失效时间模型)算法原理及Spark MLlib调用实例(Scala/Java/python)

20 篇文章 5 订阅

生存回归(加速失效时间模型)

算法介绍:

        在spark.ml中,我们实施加速失效时间模型(Acceleratedfailure time),对于截尾数据它是一个参数化生存回归的模型。它描述了一个有对数生存时间的模型,所以它也常被称为生存分析的对数线性模型。与比例危险模型不同,因AFT模型中每个实例对目标函数的贡献是独立的,其更容易并行化。

         给定协变量的值 ,对于 可能的右截尾的随机生存时间 AFT模型下的似然函数如下:

 

其中 是指示器表明事件i发生了,即有无检测到。使 ,则对数似然函数为以下形式:

 

其中 是基线生存函数, 是对应的密度函数。

最常用的AFT模型基于韦伯分布的生存时间,生存时间的韦伯分布对应于生存时间对数的极值分布, 函数以及 函数如下:

  

  

生存时间服从韦伯分布的AFT模型的对数似然函数如下:

 

由于最小化对数似然函数的负数等于最大化后验概率,所以我们要优化的损失函数为 ,分别对 以及 求导:


可以证明AFT模型是一个凸优化问题,即是说找到凸函数的最小值取决于系数向量以及尺度参数的对数。在工具中实施的优化算法为L-BFGS

*当使用无拦截的连续非零列训练AFTSurvivalRegressionModel时,Spark MLlib为连续非零列输出零系数。这种处理与R中的生存函数survreg不同。

参数:

censorCol:

类型:字符串型。

含义:检查器列名。

featuresCol:

类型:字符串型。

含义:特征列名。

fitIntercept:

类型:布尔型。

含义:是否训练拦截对象。

labelCol:

类型:字符串型。

含义:标签列名。

maxIter:

类型:整数型。

含义:迭代次数(>=0)。

quantileProbabilities:

类型:双精度数组型。

含义:分位数概率数组。

quantilesCol:

类型:字符串型。

含义:分位数列名。

stepSize:

类型:双精度型。

含义:每次迭代优化步长。

tol:

类型:双精度型。

含义:迭代算法的收敛性。

示例:

Scala:

import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.regression.AFTSurvivalRegression

val training = spark.createDataFrame(Seq(
  (1.218, 1.0, Vectors.dense(1.560, -0.605)),
  (2.949, 0.0, Vectors.dense(0.346, 2.158)),
  (3.627, 0.0, Vectors.dense(1.380, 0.231)),
  (0.273, 1.0, Vectors.dense(0.520, 1.151)),
  (4.199, 0.0, Vectors.dense(0.795, -0.226))
)).toDF("label", "censor", "features")
val quantileProbabilities = Array(0.3, 0.6)
val aft = new AFTSurvivalRegression()
  .setQuantileProbabilities(quantileProbabilities)
  .setQuantilesCol("quantiles")

val model = aft.fit(training)

// Print the coefficients, intercept and scale parameter for AFT survival regression
println(s"Coefficients: ${model.coefficients} Intercept: " +
  s"${model.intercept} Scale: ${model.scale}")
model.transform(training).show(false)
Java:

import java.util.Arrays;
import java.util.List;

import org.apache.spark.ml.regression.AFTSurvivalRegression;
import org.apache.spark.ml.regression.AFTSurvivalRegressionModel;
import org.apache.spark.ml.linalg.VectorUDT;
import org.apache.spark.ml.linalg.Vectors;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

List<Row> data = Arrays.asList(
  RowFactory.create(1.218, 1.0, Vectors.dense(1.560, -0.605)),
  RowFactory.create(2.949, 0.0, Vectors.dense(0.346, 2.158)),
  RowFactory.create(3.627, 0.0, Vectors.dense(1.380, 0.231)),
  RowFactory.create(0.273, 1.0, Vectors.dense(0.520, 1.151)),
  RowFactory.create(4.199, 0.0, Vectors.dense(0.795, -0.226))
);
StructType schema = new StructType(new StructField[]{
  new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
  new StructField("censor", DataTypes.DoubleType, false, Metadata.empty()),
  new StructField("features", new VectorUDT(), false, Metadata.empty())
});
Dataset<Row> training = spark.createDataFrame(data, schema);
double[] quantileProbabilities = new double[]{0.3, 0.6};
AFTSurvivalRegression aft = new AFTSurvivalRegression()
  .setQuantileProbabilities(quantileProbabilities)
  .setQuantilesCol("quantiles");

AFTSurvivalRegressionModel model = aft.fit(training);

// Print the coefficients, intercept and scale parameter for AFT survival regression
System.out.println("Coefficients: " + model.coefficients() + " Intercept: "
  + model.intercept() + " Scale: " + model.scale());
model.transform(training).show(false);
Python:

from pyspark.ml.regression import AFTSurvivalRegression
from pyspark.ml.linalg import Vectors

training = spark.createDataFrame([
    (1.218, 1.0, Vectors.dense(1.560, -0.605)),
    (2.949, 0.0, Vectors.dense(0.346, 2.158)),
    (3.627, 0.0, Vectors.dense(1.380, 0.231)),
    (0.273, 1.0, Vectors.dense(0.520, 1.151)),
    (4.199, 0.0, Vectors.dense(0.795, -0.226))], ["label", "censor", "features"])
quantileProbabilities = [0.3, 0.6]
aft = AFTSurvivalRegression(quantileProbabilities=quantileProbabilities,
                            quantilesCol="quantiles")

model = aft.fit(training)

# Print the coefficients, intercept and scale parameter for AFT survival regression
print("Coefficients: " + str(model.coefficients))
print("Intercept: " + str(model.intercept))
print("Scale: " + str(model.scale))
model.transform(training).show(truncate=False)


  • 1
    点赞
  • 13
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值