This is the ml (DataFrame-based) version of the decision tree API, as distinct from the mllib (RDD-based) version.
Input

| Param name | Type(s) | Default | Description |
|---|---|---|---|
| labelCol | Double | "label" | Label |
| featuresCol | Vector | "features" | Feature vector |
Output

| Param name | Type(s) | Default | Description | Notes |
|---|---|---|---|---|
| predictionCol | Double | "prediction" | Predicted label | |
| rawPredictionCol | Vector | "rawPrediction" | Vector of length # classes, with the counts of training instance labels at the tree node which makes the prediction | Classification only |
| probabilityCol | Vector | "probability" | Vector of length # classes equal to rawPrediction normalized to a multinomial distribution | Classification only |
| varianceCol | Double | | Variance of the prediction | Regression only |
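To make the relationship between rawPrediction and probability concrete, here is a small illustrative computation (pure Python, not tied to any model; the counts are hypothetical): if the tree node making the prediction saw 3 training instances of class 0 and 1 of class 1, probability is just that count vector normalized:

>>> raw = [3.0, 1.0]  # hypothetical per-class instance counts at the node
>>> [c / sum(raw) for c in raw]
[0.75, 0.25]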
API functions
Decision tree classifier
class pyspark.ml.classification.DecisionTreeClassifier(self, featuresCol="features", labelCol="label", predictionCol="prediction", probabilityCol="probability", rawPredictionCol="rawPrediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini")
maxDepth=5 means the maximum depth of the tree is 5.
maxBins=32 means the maximum number of bins used when discretizing continuous features.
impurity="gini" is the criterion used to choose the splitting feature at each tree node; the other supported value is "entropy" (the information-gain criterion used by ID3/C4.5-style trees).
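As a quick, hedged sketch of these parameters in use (plain construction plus the corresponding getters; names as in pyspark.ml, though getter availability can vary slightly across Spark versions):

>>> from pyspark.ml.classification import DecisionTreeClassifier
>>> dt = DecisionTreeClassifier(maxDepth=5, maxBins=32, impurity="gini")
>>> dt.getMaxDepth()
5
>>> dt.getMaxBins()
32
>>> dt.getImpurity()
'gini'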
Code example
>>> from pyspark.mllib.linalg import Vectors
>>> from pyspark.ml.feature import StringIndexer
>>> from pyspark.ml.classification import DecisionTreeClassifier
>>> df = sqlContext.createDataFrame([
... (1.0, Vectors.dense(1.0)),
... (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
>>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed")
>>> si_model = stringIndexer.fit(df)
>>> td = si_model.transform(df)
>>> dt = DecisionTreeClassifier(maxDepth=2, labelCol="indexed")
>>> model = dt.fit(td)
>>> model.numNodes
3
>>> model.depth
1
>>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
>>> result = model.transform(test0).head()
>>> result.prediction
0.0
>>> result.probability
DenseVector([1.0, 0.0])
>>> result.rawPrediction
DenseVector([1.0, 0.0])
>>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
>>> model.transform(test1).head().prediction
1.0
Decision tree model
class pyspark.ml.classification.DecisionTreeClassificationModel(java_model)
The model is obtained by fitting the decision tree classifier on training data.
Its main member function is transform, which predicts the class of test samples and returns the result as a DataFrame.
transform(dataset, params=None)
Transforms the input dataset with optional parameters.
Parameters:
dataset – input dataset, which is an instance of pyspark.sql.DataFrame
params – an optional param map that overrides embedded params.
Returns:
transformed dataset
New in version 1.3.0.
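As a usage sketch (continuing the example above, where model and test0 are already defined; the appended columns follow the output table at the top of this page):

>>> preds = model.transform(test0)
>>> sorted(preds.columns)
['features', 'prediction', 'probability', 'rawPrediction']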