# Data source (数据地址): http://archive.ics.uci.edu/ml/datasets/Wine
from pyspark.ml.classification import DecisionTreeClassificationModel
from pyspark.ml.classification import DecisionTreeClassifier
from pyspark.ml import Pipeline,PipelineModel
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.linalg import Vector,Vectors
from pyspark.sql import Row
from pyspark.ml.feature import IndexToString,StringIndexer,VectorIndexer
def getFeaAndLab(x):
    """Convert one split CSV record into keyword arguments for ``Row(**...)``.

    Field 0 of *x* is used as the class label. Fields 1 through 13 are parsed
    as floats and packed into a dense feature vector.

    Args:
        x: Indexable sequence of field values for one input line.

    Returns:
        dict with key 'features' (``Vectors.dense`` of the 13 floats) followed
        by key 'label' (field 0 converted with ``str``).

    Raises:
        IndexError: if *x* has fewer than 14 fields.
        ValueError: if one of fields 1-13 cannot be parsed by ``float``.
    """
    # Parse the feature columns left to right, before touching the label,
    # so malformed input fails at the same point as a field-by-field parse.
    feature_values = [float(x[col]) for col in range(1, 14)]
    record = {'features': Vectors.dense(feature_values)}
    record['label'] = str(x[0])
    return record
def model(data, seed=None):
    """Train a decision-tree pipeline on *data* and report test-set accuracy.

    The data is indexed, split 70/30 into train/test sets, a
    DecisionTreeClassifier pipeline is fitted on the training part, and the
    predictions on the test part are shown and scored.

    Args:
        data: Spark DataFrame with a string column 'label' and a vector
            column 'features' (the shape produced by getFeaAndLab).
        seed: Optional seed for ``randomSplit``. Pass an int to make the
            train/test split reproducible. The default None keeps the
            previous behaviour (a different split on every run).

    Returns:
        tuple (dtPipelineModel, dtAccuracy): the fitted PipelineModel and the
        test-set accuracy as a float. Existing callers may ignore it.
    """
    # ------------------------ data processing ------------------------
    # The indexers are fitted on the full data set, before the split, so that
    # every label / feature category is present in the index even if the
    # random split leaves it out of the training portion.
    labelIndexer = StringIndexer(inputCol='label', outputCol='indexedLabel').fit(data)
    featureIndexer = VectorIndexer(inputCol='features', outputCol='indexedFeatures').fit(data)
    # Maps numeric predictions back to the original label strings.
    labelConverter = IndexToString(inputCol='prediction', outputCol='predictedLabel',
                                   labels=labelIndexer.labels)
    # ------------------------ choose your model ----------------------
    dtClassifier = DecisionTreeClassifier().setLabelCol("indexedLabel").setFeaturesCol("indexedFeatures")
    # print("DecisionTree parameters:\n" + dtClassifier.explainParams())  # explain parameters
    # ------------------------ pipeline -------------------------------
    dtPipeline = Pipeline().setStages([labelIndexer, featureIndexer, dtClassifier, labelConverter])
    # ------------------------ split ----------------------------------
    # randomSplit samples rows randomly; a fixed seed makes it repeatable.
    trainingData, testData = data.randomSplit([0.7, 0.3], seed=seed)
    # ------------------------ train ----------------------------------
    dtPipelineModel = dtPipeline.fit(trainingData)
    # ------------------------ test -----------------------------------
    dtPredictions = dtPipelineModel.transform(testData)
    # ------------------------ show -----------------------------------
    dtPredictions.select("predictedLabel", "label", "features").show(20)
    # ------------------------ evaluate -------------------------------
    # The evaluator's default metric is "f1". Request "accuracy" explicitly
    # so the printed value really is the accuracy the variable name claims.
    evaluator = (MulticlassClassificationEvaluator()
                 .setLabelCol("indexedLabel")
                 .setPredictionCol("prediction")
                 .setMetricName("accuracy"))
    dtAccuracy = evaluator.evaluate(dtPredictions)
    print(dtAccuracy)
    return dtPipelineModel, dtAccuracy
def main(path='file:///usr/local/spark/mycode/ml/wine.txt'):
    """Load the comma-separated data file, build a DataFrame and run model().

    Args:
        path: URI of the data file. Each non-blank line is one record, with
            the label in field 0 and 13 numeric features after it.
    """
    # The original relied on a global `spark` that is not defined in this
    # file; it only exists implicitly inside the pyspark shell. getOrCreate()
    # reuses such a running session, or creates one when run as a script.
    from pyspark.sql import SparkSession
    spark = SparkSession.builder.appName('WineDecisionTree').getOrCreate()
    # Read the raw text lines.
    rdd = spark.sparkContext.textFile(path)
    # Drop blank lines (e.g. a trailing newline at EOF), which would otherwise
    # fail in getFeaAndLab. Split the rest on commas and convert to a DataFrame.
    data = (rdd.filter(lambda line: line.strip())
               .map(lambda line: line.split(','))
               .map(lambda fields: Row(**getFeaAndLab(fields)))
               .toDF())
    model(data)
# Entry point: run the pipeline only when this file is executed directly,
# not when it is imported as a module.
if "__main__" == __name__:
    main()