Spark2 线性回归


  1. Spark session available as 'spark'.
  2. Welcome to Spark version 2.0.1

  3. import org.apache.spark.sql.SparkSession
  4. import org.apache.spark.sql.Dataset
  5. import org.apache.spark.sql.Row
  6. import org.apache.spark.sql.DataFrame
  7. import org.apache.spark.sql.Column
  8. import org.apache.spark.sql.DataFrameReader
  9. import org.apache.spark.rdd.RDD
  10. import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
  11. import org.apache.spark.sql.Encoder
  12. import org.apache.spark.ml.linalg.Vectors
  13. import org.apache.spark.ml.feature.VectorAssembler
  14. import org.apache.spark.ml.regression.LinearRegression


  15. scala> val spark = SparkSession.builder().appName("Spark SQL basic example").config("spark.some.config.option", "some-value").getOrCreate()
  16. 16/11/05 15:07:06 WARN SparkSession$Builder: Use an existing SparkSession, some configuration may not take effect.
  17. spark: org.apache.spark.sql.SparkSession = org.apache.spark.sql.SparkSession@3300811b

  18. scala>

  19. scala> // For implicit conversions like converting RDDs to DataFrames

  20. scala> import spark.implicits._
  21. import spark.implicits._

  22. scala>

  23. scala> // Load training data

  24. scala> val data: DataFrame = spark.read.format("csv").option("header", true).load("hdfs://ns1/datafile/wfp.csv")
  25. data: org.apache.spark.sql.DataFrame = [windSpeed: string, power: string]

  26. scala>

  27. scala> data.cache()
  28. 16/11/05 15:07:12 WARN CacheManager: Asked to cache already cached data.
  29. res21: data.type = [windSpeed: string, power: string]

  30. scala>

  31. scala> data.limit(10).show
  32. 16/11/05 15:07:13 WARN Executor: 1 block locks were not released by TID = 352:
  33. [rdd_9_0]
  34. +---------+-----+
  35. |windSpeed|power|
  36. +---------+-----+
  37. | 3        |  20|
  38. | 3.5      |  30|
  39. | 4        |  50|
  40. | 4.5      | 100|
  41. | 5        | 200|
  42. | 5.5      | 300|
  43. | 6        | 400|
  44. | 6.5      | 500|
  45. | 7        | 600|
  46. | 7.5      | 700|
  47. +---------+-----+


  48. scala>

  49. scala> // 字段类型转换(string -> Double);windSpeed 取三次方(风功率与风速的立方近似成正比),并定义别名

  50. scala> val data1= data.select( ( data("windSpeed").cast("Double")*data("windSpeed").cast("Double")*data("windSpeed").cast("Double") ).as("windSpeed"),data("power").cast("Double") )
  51. data1: org.apache.spark.sql.DataFrame = [windSpeed: double, power: double]

  52. scala>

  53. scala> data1.limit(10).show
  54. 16/11/05 15:07:16 WARN Executor: 1 block locks were not released by TID = 353:
  55. [rdd_9_0]
  56. +---------+-----+
  57. |windSpeed|power|
  58. +---------+-----+
  59. | 27.0    | 20.0|
  60. | 42.875  | 30.0|
  61. | 64.0    | 50.0|
  62. | 91.125  |100.0|
  63. | 125.0   |200.0|
  64. | 166.375 |300.0|
  65. | 216.0   |400.0|
  66. | 274.625 |500.0|
  67. | 343.0   |600.0|
  68. | 421.875 |700.0|
  69. +---------+-----+


  70. scala>

  71. scala> val data2=data1.filter("power>20 and windSpeed<3500").orderBy("windSpeed")
  72. data2: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [windSpeed: double, power: double]

  73. scala> data2.limit(10).show
  74. +---------+-----+
  75. |windSpeed|power|
  76. +---------+-----+
  77. | 42.875  | 30.0|
  78. | 64.0    | 50.0|
  79. | 91.125  |100.0|
  80. | 125.0   |200.0|
  81. | 166.375 |300.0|
  82. | 216.0   |400.0|
  83. | 274.625 |500.0|
  84. | 343.0   |600.0|
  85. | 421.875 |700.0|
  86. | 512.0   |800.0|
  87. +---------+-----+


  88. scala>
  89.      | // 转换成 Label 和 Features 格式:将 windSpeed 组装为 features 向量列,label 直接使用 power 列

  90. scala> val assembler = new VectorAssembler().setInputCols(Array("windSpeed")).setOutputCol("features")
  91. assembler: org.apache.spark.ml.feature.VectorAssembler = vecAssembler_6f4ca7e2549a

  92. scala>

  93. scala> val output: DataFrame = assembler.transform(data2)
  94. output: org.apache.spark.sql.DataFrame = [windSpeed: double, power: double ... 1 more field]

  95. scala>

  96. scala> output.printSchema()
  97. root
  98.  |-- windSpeed: double (nullable = true)
  99.  |-- power: double (nullable = true)
  100.  |-- features: vector (nullable = true)


  101. scala>

  102. scala> output.limit(10).show
  103. +---------+-----+---------+
  104. |windSpeed|power| features|
  105. +---------+-----+---------+
  106. | 42.875  | 30.0| [42.875]|
  107. | 64.0    | 50.0| [64.0]  |
  108. | 91.125  |100.0| [91.125]|
  109. | 125.0   |200.0| [125.0] |
  110. | 166.375 |300.0|[166.375]|
  111. | 216.0   |400.0| [216.0] |
  112. | 274.625 |500.0|[274.625]|
  113. | 343.0   |600.0| [343.0] |
  114. | 421.875 |700.0|[421.875]|
  115. | 512.0   |800.0| [512.0] |
  116. +---------+-----+---------+


  117. scala>

  118. scala> val training = output
  119. training: org.apache.spark.sql.DataFrame = [windSpeed: double, power: double ... 1 more field]

  120. scala>

  121. scala> training.cache()
  122. res28: training.type = [windSpeed: double, power: double ... 1 more field]

  123. scala>

  124. scala> // 设置线性回归参数

  125. scala> val lr = new LinearRegression().setLabelCol("power").setFeaturesCol("features").setMaxIter(10).setRegParam(0.3).setElasticNetParam(0.8).setFitIntercept(true)
  126. lr: org.apache.spark.ml.regression.LinearRegression = linReg_58dde0c17920

  127. scala>

  128. scala> // Fit the model

  129. scala> val lrModel = lr.fit(training)
  130. lrModel: org.apache.spark.ml.regression.LinearRegressionModel = linReg_58dde0c17920

  131. scala>

  132. scala> // Print the coefficients and intercept for linear regression

  133. scala> println(s"Coefficients: ${lrModel.coefficients} Intercept: ${lrModel.intercept}")
  134. Coefficients: [0.5113433522323718] Intercept: 389.0639900098431

  135. scala>

  136. scala> // Summarize the model over the training set and print out some metrics

  137. scala> val trainingSummary = lrModel.summary
  138. trainingSummary: org.apache.spark.ml.regression.LinearRegressionTrainingSummary = org.apache.spark.ml.regression.LinearRegressionTrainingSummary@3ee2b229

  139. scala> println(s"numIterations: ${trainingSummary.totalIterations}")
  140. numIterations: 4

  141. scala> println(s"objectiveHistory: ${trainingSummary.objectiveHistory.toList}")
  142. objectiveHistory: List(0.4791666666666666, 0.3927565804317654, 0.1589356484493222, 0.08099533778850757)

  143. scala> trainingSummary.residuals.show()
  144. +-------------------+
  145. |          residuals|
  146. +-------------------+
  147. |-380.98783623680606|
  148. | -371.7899645527149|
  149. |-335.66015298201796|
  150. |-252.98190903888957|
  151. |-174.13874023750395|
  152. | -99.51415409203543|
  153. |-29.491658116658186|
  154. | 35.5452401744534  |
  155. | 95.21303326712507 |
  156. | 149.12821364718252|
  157. | 196.90727380045155|
  158. | 238.16670621275784|
  159. | 272.5230033699271 |
  160. | 299.592657757785  |
  161. | 318.99216186215745|
  162. | 280.33800816887015|
  163. | 233.24668916374844|
  164. | 177.33469733261836|
  165. | 112.2185251613057 |
  166. | 37.51466513563605 |
  167. +-------------------+
  168. only showing top 20 rows


  169. scala> println(s"RMSE: ${trainingSummary.rootMeanSquaredError}")
  170. RMSE: 231.88264620805631

  171. scala> println(s"r2: ${trainingSummary.r2}")
  172. r2: 0.8318466871093443


来自“ITPUB博客”,链接:http://blog.itpub.net/29070860/viewspace-2127853/,如需转载,请注明出处,否则将追究法律责任。

转载于:http://blog.itpub.net/29070860/viewspace-2127853/

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值