Spark MLlib: Training and Evaluating a Decision Tree Model

1. Add the dependencies:

 

<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
	xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
	<modelVersion>4.0.0</modelVersion>
	<parent>
		<groupId>org.springframework.boot</groupId>
		<artifactId>spring-boot-starter-parent</artifactId>
		<version>2.4.2</version>
		<relativePath/> <!-- lookup parent from repository -->
	</parent>
	<groupId>com.siao</groupId>
	<artifactId>sparkml</artifactId>
	<version>0.0.1-SNAPSHOT</version>
	<name>sparkml</name>
	<description>Demo project for Spring Boot</description>
	<properties>
		<java.version>1.8</java.version>
		<spark.version>2.2.0</spark.version>
	</properties>
	<dependencies>
		<dependency>
			<groupId>org.springframework.boot</groupId>
			<artifactId>spring-boot-starter-web</artifactId>
		</dependency>

		<dependency>
			<groupId>org.springframework.boot</groupId>
			<artifactId>spring-boot-starter-test</artifactId>
			<scope>test</scope>
		</dependency>

		<dependency>
			<groupId>org.apache.spark</groupId>
			<artifactId>spark-core_2.11</artifactId>
			<version>${spark.version}</version>
			<exclusions>
				<exclusion>
					<groupId>org.slf4j</groupId>
					<artifactId>slf4j-log4j12</artifactId>
				</exclusion>
			</exclusions>
		</dependency>

		<!-- https://mvnrepository.com/artifact/io.netty/netty-all -->
		<dependency>
			<groupId>io.netty</groupId>
			<artifactId>netty-all</artifactId>
			<version>4.1.11.Final</version>
		</dependency>

		<dependency>
			<groupId>org.apache.spark</groupId>
			<artifactId>spark-sql_2.11</artifactId>
			<version>${spark.version}</version>
		</dependency>
		<dependency>
			<groupId>org.codehaus.janino</groupId>
			<artifactId>janino</artifactId>
			<version>3.0.8</version>
		</dependency>
		<dependency>
			<groupId>org.apache.spark</groupId>
			<artifactId>spark-mllib_2.11</artifactId>
			<version>${spark.version}</version>
		</dependency>

		<dependency>
			<groupId>com.google.guava</groupId>
			<artifactId>guava</artifactId>
			<version>14.0.1</version>
		</dependency>
	</dependencies>

	<build>
		<plugins>
			<plugin>
				<groupId>org.springframework.boot</groupId>
				<artifactId>spring-boot-maven-plugin</artifactId>
			</plugin>
		</plugins>
	</build>

</project>

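2. Switch the project to Scala 2.11

Note that every Spark artifact used here is the Scala 2.11 build (the _2.11 suffix in the artifactId).

How to change the Scala version to 2.11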

 

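3. Create the SparkSession

    import org.apache.spark.sql.SparkSession

//Create a local SparkSession
    val sparkSession1 = SparkSession.builder()
      .appName("SparkSQLDemo")
      .master("local")
      //Overrides the memory Spark assumes the JVM has, so the minimum-memory check passes on a small local heap
      .config("spark.testing.memory", "2147480000")
      .getOrCreate()
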

4. Read the data from HDFS

//Read the headerless CSV from HDFS and let Spark infer the column types
    val dataWithoutHeader = sparkSession1.read.
      option("inferSchema", true).
      option("header", false).
      csv("hdfs://192.168.48.101:8020/user/ds/covtype.data")

Sharing the covtype.data dataset through CSDN costs money, so join my QQ group 139878017 (计算机少林寺) and I will share it with everyone for free.

5. Add the column names

 val colNames = Seq(
      "Elevation", "Aspect", "Slope",
      "Horizontal_Distance_To_Hydrology", "Vertical_Distance_To_Hydrology",
      "Horizontal_Distance_To_Roadways",
      "Hillshade_9am", "Hillshade_Noon", "Hillshade_3pm",
      "Horizontal_Distance_To_Fire_Points"
    ) ++ (
      (0 until 4).map(i => s"Wilderness_Area_$i")
      ) ++ (
      (0 until 40).map(i => s"Soil_Type_$i")
      ) ++ Seq("Cover_Type")
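The name list above has 10 numeric columns, 4 wilderness-area columns, 40 soil-type columns, and 1 label, so 55 names in total. Because toDF(colNames:_*) fails when the number of names differs from the number of columns in the DataFrame, it can help to check the widths first. The following is a minimal sketch; ColumnCheck and requireSameWidth are hypothetical names that do not appear in the original code.

```scala
object ColumnCheck {
  // Expected width derived from the name list: 10 numeric + 4 wilderness + 40 soil + 1 label.
  val expectedWidth: Int = 10 + 4 + 40 + 1

  // Hypothetical helper: throws IllegalArgumentException when the widths differ.
  def requireSameWidth(actualColumns: Int, names: Seq[String]): Unit =
    require(
      actualColumns == names.length,
      s"DataFrame has $actualColumns columns but ${names.length} names were supplied"
    )

  def main(args: Array[String]): Unit = {
    // Stand-in list of 55 placeholder names, used only to exercise the check.
    val placeholderNames = (0 until expectedWidth).map(i => s"col_$i")
    requireSameWidth(55, placeholderNames)
    println(s"Width check passed for $expectedWidth columns")
  }
}
```

In the code from this post the call would look like ColumnCheck.requireSameWidth(dataWithoutHeader.columns.length, colNames), placed just before toDF.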

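6. Cast the label, split the data, and assemble the features

    import org.apache.spark.ml.feature.VectorAssembler
    import org.apache.spark.sql.functions.col

//Cast the label column to double
    val data = dataWithoutHeader.toDF(colNames:_*).
      withColumn("Cover_Type", col("Cover_Type").cast("double"))

//Split into a training set (90%) and a test set (10%)
    val Array(trainData, testData) = data.randomSplit(Array(0.9, 0.1))
    trainData.cache()
    testData.cache()

//Use every column except the label as an input feature
    val inputCols = trainData.columns.filter(_ != "Cover_Type")
    val assembler = new VectorAssembler().
      setInputCols(inputCols).
      setOutputCol("featureVector")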

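7. Train the decision tree and inspect the feature importances

    import org.apache.spark.ml.classification.DecisionTreeClassifier
    import scala.util.Random

//Convert the training data into feature vectors
    val assembledTrainData = assembler.transform(trainData)


//Create the decision tree classifier and set a random seed
    val classifier = new DecisionTreeClassifier().
      setSeed(Random.nextLong()).
      setLabelCol("Cover_Type").
      setFeaturesCol("featureVector").
      setPredictionCol("prediction")

//Fit the tree on the assembled training data
    val treeModel = classifier.fit(assembledTrainData)

//Print the features ranked from most to least important
    treeModel.featureImportances.toArray.zip(inputCols).
      sorted.reverse.foreach(println)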
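The ranking in step 7 pairs each importance score with its feature name, sorts the pairs, and reverses them so the largest score comes first. The sketch below shows the same idea in plain Scala without Spark. ImportanceRanking and rank are hypothetical names, and the scores are made up purely for illustration; real values come from the fitted model.

```scala
object ImportanceRanking {
  // Pairs each score with its feature name and orders the pairs from highest to lowest score.
  def rank(importances: Array[Double], names: Array[String]): Array[(Double, String)] =
    importances.zip(names).sortBy { case (score, _) => -score }

  def main(args: Array[String]): Unit = {
    // Made-up scores for three of the column names, used only to show the ordering.
    val ranked = rank(Array(0.1, 0.7, 0.2), Array("Aspect", "Elevation", "Slope"))
    ranked.foreach(println)
  }
}
```

Sorting by the negated score avoids the separate reverse step and keeps the order driven by the score alone.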
 
 
 
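8. Save the pipeline model

    import org.apache.spark.ml.Pipeline

//Chain the assembler and the classifier, fit them together, and save the fitted pipeline
    val pipeline = new Pipeline().setStages(Array(assembler, classifier))
    val model = pipeline.fit(trainData)
    model.write.overwrite.save("hdfs://192.168.48.101:8020/user/ds/models/linearRegression3")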

9. Load the pipeline model

    import org.apache.spark.ml.PipelineModel

    val model = PipelineModel.load("hdfs://192.168.48.101:8020/user/ds/models/linearRegression3")
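The steps above train, save, and reload the model, but they do not yet measure how well it performs. One possible way to do that, sketched here under assumptions, is to score the held-out test set with Spark's MulticlassClassificationEvaluator. EvaluateModel and evaluate are hypothetical names that are not part of the original code; the column names "Cover_Type" and "prediction" are the ones configured in step 7.

```scala
import org.apache.spark.ml.PipelineModel
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.sql.DataFrame

object EvaluateModel {
  // Hypothetical helper: returns (accuracy, F1) of a fitted pipeline on held-out rows.
  def evaluate(model: PipelineModel, testData: DataFrame): (Double, Double) = {
    // The pipeline assembles the feature vector and appends a "prediction" column.
    val predictions = model.transform(testData)

    val evaluator = new MulticlassClassificationEvaluator()
      .setLabelCol("Cover_Type")
      .setPredictionCol("prediction")

    // Reuse the same evaluator for both metrics by switching the metric name.
    val accuracy = evaluator.setMetricName("accuracy").evaluate(predictions)
    val f1 = evaluator.setMetricName("f1").evaluate(predictions)
    (accuracy, f1)
  }
}
```

With the variables from this post, the call would be val (accuracy, f1) = EvaluateModel.evaluate(model, testData). Both values lie between 0 and 1, and because randomSplit is unseeded here, the exact numbers will vary from run to run.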

 
