Spark2 生存分析Survival regression

  在spark.ml中,实现了加速失效时间(AFT)模型,这是一个用于检查数据的参数生存回归模型。 它描述了生存时间对数的模型,因此它通常被称为生存分析的对数线性模型。 不同于为相同目的设计的比例风险模型,AFT模型更容易并行化,因为每个实例独立地贡献于目标函数。

  当在具有常量非零列的数据集上匹配AFTSurvivalRegressionModel而没有截距时,Spark MLlib为常量非零列输出零系数。 这种行为不同于R survival :: survreg。

导入包 

1
2
3
4
5
6
7
8
9
10
11
12
import  org.apache.spark.sql.SparkSession
import  org.apache.spark.sql.Dataset
import  org.apache.spark.sql.Row
import  org.apache.spark.sql.DataFrame
import  org.apache.spark.sql.functions. _
 
import  org.apache.spark.ml.linalg.Vectors
import  org.apache.spark.ml.feature.VectorAssembler
import  org.apache.spark.ml.Pipeline
import  org.apache.spark.ml.evaluation.RegressionEvaluator
import  org.apache.spark.ml.regression.AFTSurvivalRegression
import  org.apache.spark.ml.tuning.{ CrossValidator, ParamGridBuilder }

 

导入样本数据

 

建模并调优

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
val  colArray  =  Array( "race" "poverty" "smoke" "alcohol" "agemth" "ybirth" "yschool" "pc3mth" )
  
val  assembler  =  new  VectorAssembler().setInputCols(colArray).setOutputCol( "features" )
  
val  vecDF :  DataFrame  =  assembler.transform(data)
  
val  Array(trainingDF, testDF)  =  vecDF.randomSplit(Array( 0.7 0.3 ))
  
//###########################
// 建立生存回归模型
val  AFT  =  new  AFTSurvivalRegression().setFeaturesCol( "features" ).setLabelCol( "label" ).setCensorCol( "censor" ).fit(trainingDF)
  
// 设置管道
val  pipeline  =  new  Pipeline().setStages(Array(AFT))
  
// 设置参数网格
val  paramGrid  =  new  ParamGridBuilder().addGrid(AFT.maxIter, Array( 100 500 1000 )).addGrid(AFT.tol, Array( 1 E- 2 1 E- 6 )).build()
  
// 选择(prediction, true label),计算测试误差。
// 注意RegEvaluator.isLargerBetter,评估的度量值是大的好,还是小的好,系统会自动识别
val  RegEvaluator  =  new  RegressionEvaluator().setLabelCol(AFT.getLabelCol).setPredictionCol(AFT.getPredictionCol).setMetricName( "rmse" )
  
// 设置交叉验证
val  cv  =  new  CrossValidator().setEstimator(pipeline).setEvaluator(RegEvaluator).setEstimatorParamMaps(paramGrid).setNumFolds( 3 )
  
// 执行交叉验证,并选择出最好的参数集
val  cvModel  =  cv.fit(trainingDF)
  
// 查看全部参数
cvModel.extractParamMap()
// cvModel.avgMetrics.length=cvModel.getEstimatorParamMaps.length
// cvModel.avgMetrics与cvModel.getEstimatorParamMaps中的元素一一对应
cvModel.avgMetrics.length
cvModel.avgMetrics  // 参数对应的平均度量
  
cvModel.getEstimatorParamMaps.length
cvModel.getEstimatorParamMaps  // 参数组合的集合
  
cvModel.getEvaluator.extractParamMap()   // 评估的参数
  
cvModel.getEvaluator.isLargerBetter  // 评估的度量值是大的好,还是小的好
cvModel.getNumFolds  // 交叉验证的折数
  
//################################
// 测试模型
val  predictDF :  DataFrame  =  cvModel.transform(testDF).selectExpr(
   //"race","poverty","smoke","alcohol","agemth","ybirth","yschool","pc3mth", "features",
   "label" "censor" ,
   "round(prediction,2) as prediction" ).orderBy( "label" )
predictDF.show
  
spark.stop()

 

代码执行结果 

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
// 查看全部参数
cvModel.extractParamMap()
res 2 :  org.apache.spark.ml.param.ParamMap  =
{
     cv _ baf 8 c 9 af 33 b 7 -estimator :  pipeline _ 20 ba 567066 f 7 ,
     cv _ baf 8 c 9 af 33 b 7 -estimatorParamMaps :  [Lorg.apache.spark.ml.param.ParamMap; @ 412 a 07 c 8 ,
     cv _ baf 8 c 9 af 33 b 7 -evaluator :  regEval _ 59075079 f 1 c 9 ,
     cv _ baf 8 c 9 af 33 b 7 -numFolds :  3 ,
     cv _ baf 8 c 9 af 33 b 7 -seed :  - 1191137437
}
  
// cvModel.avgMetrics.length=cvModel.getEstimatorParamMaps.length
// cvModel.avgMetrics与cvModel.getEstimatorParamMaps中的元素一一对应
cvModel.avgMetrics.length
res 3 :  Int  =  6
  
cvModel.avgMetrics  // 参数对应的平均度量
res 4 :  Array[Double]  =  Array( 18.53 17.53 19.53 17.63 18.53 18.93 )
  
cvModel.getEstimatorParamMaps.length
res 5 :  Int  =  6
  
cvModel.getEstimatorParamMaps  // 参数组合的集合
res 6 :  Array[org.apache.spark.ml.param.ParamMap]  =
Array({
     aftSurvReg _ a 7 e 5 bc 450599 -maxIter :  100 ,
     aftSurvReg _ a 7 e 5 bc 450599 -tol :  0.01
}, {
     aftSurvReg _ a 7 e 5 bc 450599 -maxIter :  100 ,
     aftSurvReg _ a 7 e 5 bc 450599 -tol :  1.0 E- 6
}, {
     aftSurvReg _ a 7 e 5 bc 450599 -maxIter :  500 ,
     aftSurvReg _ a 7 e 5 bc 450599 -tol :  0.01
}, {
     aftSurvReg _ a 7 e 5 bc 450599 -maxIter :  500 ,
     aftSurvReg _ a 7 e 5 bc 450599 -tol :  1.0 E- 6
}, {
     aftSurvReg _ a 7 e 5 bc 450599 -maxIter :  1000 ,
     aftSurvReg _ a 7 e 5 bc 450599 -tol :  0.01
}, {
     aftSurvReg _ a 7 e 5 bc 450599 -maxIter :  1000 ,
     aftSurvReg _ a 7 e 5 bc 450599 -tol :  1.0 E- 6
})
  
cvModel.getEvaluator.extractParamMap()   // 评估的参数
res 7 :  org.apache.spark.ml.param.ParamMap  =
{
     regEval _ 59075079 f 1 c 9 -labelCol :  label,
     regEval _ 59075079 f 1 c 9 -metricName :  rmse,
     regEval _ 59075079 f 1 c 9 -predictionCol :  prediction
}
  
cvModel.getEvaluator.isLargerBetter  // 评估的度量值是大的好,还是小的好
res 8 :  Boolean  =  false    // 这里显示“评估的度量值”是小的好
  
cvModel.getNumFolds  // 交叉验证的折数
res 9 :  Int  =  3
  
//################################
// 测试模型
val  predictDF :  DataFrame  =  cvModel.transform(testDF).selectExpr(
      |    //"race","poverty","smoke","alcohol","agemth","ybirth","yschool","pc3mth", "features",
      |    "label" "censor" ,
      |    "round(prediction,2) as prediction" ).orderBy( "label" )
predictDF :  org.apache.spark.sql.DataFrame  =  [label :  double, censor :  double ...  1  more field]
  
predictDF.show
+-----+------+----------+                                                     
|label|censor|prediction|
+-----+------+----------+
|   1.0 |    1.0 |       15.4 |
|   1.0 |    1.0 |      20.02 |
|   1.0 |    1.0 |      18.73 |
|   1.0 |    1.0 |      21.58 |
|   1.0 |    1.0 |       21.8 |
|   1.0 |    1.0 |       21.8 |
|   1.0 |    1.0 |      14.37 |
|   1.0 |    1.0 |       13.5 |
|   1.0 |    1.0 |      15.82 |
|   1.0 |    1.0 |      19.51 |
|   1.0 |    1.0 |      13.17 |
|   1.0 |    1.0 |       11.9 |
|   1.0 |    1.0 |      17.26 |
|   1.0 |    1.0 |      13.57 |
|   1.0 |    1.0 |      11.57 |
|   1.0 |    1.0 |      13.55 |
|   1.0 |    1.0 |      10.95 |
|   1.0 |    1.0 |      14.92 |
|   1.0 |    1.0 |      12.25 |
|   1.0 |    1.0 |      19.62 |
+-----+------+----------+
only showing top  20  rows

文章出处:http://www.cnblogs.com/wwxbi/p/6150352.html
  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值