spark.ml implements the Accelerated Failure Time (AFT) model, a parametric survival regression model for censored data. It models the logarithm of the survival time, so it is often referred to as a log-linear model for survival analysis. Unlike the proportional hazards model designed for the same purpose, the AFT model is easier to parallelize because each instance contributes to the objective function independently.
When an AFTSurvivalRegressionModel is fit without an intercept on a dataset that contains constant nonzero columns, Spark MLlib outputs zero coefficients for those constant nonzero columns. This behavior differs from R's survival::survreg.
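For readers unfamiliar with the term, "log-linear" here refers to the usual AFT formulation, in which the logarithm of the survival time is linear in the covariates plus a scaled error term (standard textbook notation below, not Spark-specific):

```latex
% AFT (log-linear) model: T_i is the survival time of instance i,
% x_i its covariate vector, \beta the coefficients, \sigma a scale
% parameter and \epsilon_i the error term.
\log T_i = \beta^{\top} x_i + \sigma\,\epsilon_i
```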
Import the packages

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.Row
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.regression.AFTSurvivalRegression
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}
```
Load the sample data
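The original post does not show this step. Below is a minimal sketch, assuming the sample data lives in a CSV file (the path and the header layout are assumptions) with the survival time in `label`, the censoring indicator in `censor` (1.0 = event observed, 0.0 = censored), and the feature columns used later:

```scala
// Minimal sketch of the data-loading step (file path and schema are
// assumptions, not part of the original post).
val spark = SparkSession.builder()
  .appName("AFTSurvivalRegressionCV")
  .getOrCreate()

// label = survival time, censor = 1.0 if the event was observed, 0.0 if censored
val data: DataFrame = spark.read
  .option("header", "true")
  .option("inferSchema", "true")
  .csv("data/bfeed.csv")
  .selectExpr("label", "censor",
    "race", "poverty", "smoke", "alcohol", "agemth", "ybirth", "yschool", "pc3mth")
```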
Build and tune the model
```scala
val colArray = Array("race", "poverty", "smoke", "alcohol", "agemth", "ybirth", "yschool", "pc3mth")

// Assemble the feature columns into a single vector column
val assembler = new VectorAssembler().setInputCols(colArray).setOutputCol("features")
val vecDF: DataFrame = assembler.transform(data)
val Array(trainingDF, testDF) = vecDF.randomSplit(Array(0.7, 0.3))

//###########################
// Build the survival regression estimator
// (left unfitted here so that CrossValidator can actually tune its parameters)
val AFT = new AFTSurvivalRegression()
  .setFeaturesCol("features")
  .setLabelCol("label")
  .setCensorCol("censor")

// Set up the pipeline
val pipeline = new Pipeline().setStages(Array(AFT))

// Set up the parameter grid
val paramGrid = new ParamGridBuilder()
  .addGrid(AFT.maxIter, Array(100, 500, 1000))
  .addGrid(AFT.tol, Array(1E-2, 1E-6))
  .build()

// Select (prediction, true label) and compute the test error.
// Note RegEvaluator.isLargerBetter: whether a larger or smaller metric value
// is better is recognized automatically for the chosen metric.
val RegEvaluator = new RegressionEvaluator()
  .setLabelCol(AFT.getLabelCol)
  .setPredictionCol(AFT.getPredictionCol)
  .setMetricName("rmse")

// Set up cross-validation
val cv = new CrossValidator()
  .setEstimator(pipeline)
  .setEvaluator(RegEvaluator)
  .setEstimatorParamMaps(paramGrid)
  .setNumFolds(3)

// Run cross-validation and pick the best parameter set
val cvModel = cv.fit(trainingDF)

// Inspect all parameters
cvModel.extractParamMap()

// cvModel.avgMetrics.length == cvModel.getEstimatorParamMaps.length
// The elements of cvModel.avgMetrics correspond one-to-one to those of cvModel.getEstimatorParamMaps
cvModel.avgMetrics.length
cvModel.avgMetrics                      // average metric for each parameter combination
cvModel.getEstimatorParamMaps.length
cvModel.getEstimatorParamMaps           // the set of parameter combinations
cvModel.getEvaluator.extractParamMap()  // the evaluator's parameters
cvModel.getEvaluator.isLargerBetter     // whether a larger metric value is better
cvModel.getNumFolds                     // number of cross-validation folds

//################################
// Test the model
val predictDF: DataFrame = cvModel.transform(testDF).selectExpr(
  //"race","poverty","smoke","alcohol","agemth","ybirth","yschool","pc3mth", "features",
  "label", "censor",
  "round(prediction,2) as prediction"
).orderBy("label")
predictDF.show
spark.stop()
```
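Besides the inspection calls above, it is often useful to pull the winning fitted model out of the CrossValidatorModel. A small sketch, assuming the pipeline layout used here (the AFT stage sits at index 0):

```scala
import org.apache.spark.ml.PipelineModel
import org.apache.spark.ml.regression.AFTSurvivalRegressionModel

// The best model found by CrossValidator is the fitted pipeline itself;
// its first (and only) stage is the fitted AFT model.
val bestPipeline = cvModel.bestModel.asInstanceOf[PipelineModel]
val bestAFT = bestPipeline.stages(0).asInstanceOf[AFTSurvivalRegressionModel]

// Coefficients, intercept and scale of the winning AFT fit,
// plus the parameter values it was trained with.
println(s"coefficients: ${bestAFT.coefficients}")
println(s"intercept:    ${bestAFT.intercept}")
println(s"scale:        ${bestAFT.scale}")
println(bestAFT.extractParamMap())
```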
Execution results
```
// Inspect all parameters
cvModel.extractParamMap()
res2: org.apache.spark.ml.param.ParamMap =
{
    cv_baf8c9af33b7-estimator: pipeline_20ba567066f7,
    cv_baf8c9af33b7-estimatorParamMaps: [Lorg.apache.spark.ml.param.ParamMap;@412a07c8,
    cv_baf8c9af33b7-evaluator: regEval_59075079f1c9,
    cv_baf8c9af33b7-numFolds: 3,
    cv_baf8c9af33b7-seed: -1191137437
}

// cvModel.avgMetrics.length == cvModel.getEstimatorParamMaps.length
// The elements of cvModel.avgMetrics correspond one-to-one to those of cvModel.getEstimatorParamMaps
cvModel.avgMetrics.length
res3: Int = 6

cvModel.avgMetrics   // average metric for each parameter combination
res4: Array[Double] = Array(18.53, 17.53, 19.53, 17.63, 18.53, 18.93)

cvModel.getEstimatorParamMaps.length
res5: Int = 6

cvModel.getEstimatorParamMaps   // the set of parameter combinations
res6: Array[org.apache.spark.ml.param.ParamMap] = Array({
    aftSurvReg_a7e5bc450599-maxIter: 100,
    aftSurvReg_a7e5bc450599-tol: 0.01
}, {
    aftSurvReg_a7e5bc450599-maxIter: 100,
    aftSurvReg_a7e5bc450599-tol: 1.0E-6
}, {
    aftSurvReg_a7e5bc450599-maxIter: 500,
    aftSurvReg_a7e5bc450599-tol: 0.01
}, {
    aftSurvReg_a7e5bc450599-maxIter: 500,
    aftSurvReg_a7e5bc450599-tol: 1.0E-6
}, {
    aftSurvReg_a7e5bc450599-maxIter: 1000,
    aftSurvReg_a7e5bc450599-tol: 0.01
}, {
    aftSurvReg_a7e5bc450599-maxIter: 1000,
    aftSurvReg_a7e5bc450599-tol: 1.0E-6
})

cvModel.getEvaluator.extractParamMap()   // the evaluator's parameters
res7: org.apache.spark.ml.param.ParamMap =
{
    regEval_59075079f1c9-labelCol: label,
    regEval_59075079f1c9-metricName: rmse,
    regEval_59075079f1c9-predictionCol: prediction
}

cvModel.getEvaluator.isLargerBetter   // whether a larger metric value is better
res8: Boolean = false
// here a smaller metric value is better

cvModel.getNumFolds   // number of cross-validation folds
res9: Int = 3

//################################
// Test the model
val predictDF: DataFrame = cvModel.transform(testDF).selectExpr(
  //"race","poverty","smoke","alcohol","agemth","ybirth","yschool","pc3mth", "features",
  "label", "censor",
  "round(prediction,2) as prediction"
).orderBy("label")
predictDF: org.apache.spark.sql.DataFrame = [label: double, censor: double ... 1 more field]

predictDF.show
+-----+------+----------+
|label|censor|prediction|
+-----+------+----------+
|  1.0|   1.0|      15.4|
|  1.0|   1.0|     20.02|
|  1.0|   1.0|     18.73|
|  1.0|   1.0|     21.58|
|  1.0|   1.0|      21.8|
|  1.0|   1.0|      21.8|
|  1.0|   1.0|     14.37|
|  1.0|   1.0|      13.5|
|  1.0|   1.0|     15.82|
|  1.0|   1.0|     19.51|
|  1.0|   1.0|     13.17|
|  1.0|   1.0|      11.9|
|  1.0|   1.0|     17.26|
|  1.0|   1.0|     13.57|
|  1.0|   1.0|     11.57|
|  1.0|   1.0|     13.55|
|  1.0|   1.0|     10.95|
|  1.0|   1.0|     14.92|
|  1.0|   1.0|     12.25|
|  1.0|   1.0|     19.62|
+-----+------+----------+
only showing top 20 rows
```
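The `prediction` column above is the model's predicted survival time. If quantiles of the predicted survival distribution are also of interest, AFTSurvivalRegression can emit them directly; a brief sketch (the quantile probabilities and the output column name are arbitrary choices, not from the original post):

```scala
// Refit a single AFT model that also outputs survival-time quantiles.
val aftWithQuantiles = new AFTSurvivalRegression()
  .setFeaturesCol("features")
  .setLabelCol("label")
  .setCensorCol("censor")
  .setQuantileProbabilities(Array(0.25, 0.5, 0.75))
  .setQuantilesCol("quantiles")

aftWithQuantiles.fit(trainingDF)
  .transform(testDF)
  .select("label", "censor", "prediction", "quantiles")
  .show(5)
```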
Source: http://www.cnblogs.com/wwxbi/p/6150352.html