Spark MLlib 学习入门笔记 - RDD基础

RDD(Resilient Distributed Datasets)分布式弹性数据集,将数据分布存储在不同节点的计算机内存中进行存储和处理。RDD的任务被分成两部分:Transformation和Action。Transformation用于对RDD的创建,即一个RDD转换为另一个RDD,Action是数据计算执行部分,如count、reduce、collect等。 Spark文档里有相关的说明,网上还有一个Spark文档的中文翻译,可以参考。从编程的角度来看,一开始把RDD当成一个数组,并记住它的运行任务由Transformation算子和Action算子共同完成,得到运算结果就好了。

以下是我在学习RDD时完成的一道练习题,通过这道练习,可以掌握RDD编程的基本思路和方法。

对机器学习数据iris.data数据集进行简单的处理,包括filter、count、distince和分类统计(求和、最大值、最小值和平均值)。

1. 数据说明


[plain] view plain copy

  1. 5.1,3.5,1.4,0.2,Iris-setosa  

  2. 4.9,3.0,1.4,0.2,Iris-setosa  

  3. 4.7,3.2,1.3,0.2,Iris-setosa  

  4. 4.6,3.1,1.5,0.2,Iris-setosa  

  5. 5.0,3.6,1.4,0.2,Iris-setosa  

第一个数据septal length;第二个数据sepal width;第三个数据petal length;第四个数据petal width;第五个数据class标签;在程序中用name表示,相当于key。


2.源代码

Iris.scala


[plain] view plain copy

  1. package basic.iris  

  2.   

  3. /**  

  4.   * Created by Oliver on 2017/5/18.  

  5.   */  

  6. //1. sepal length in cm  

  7. //2. sepal width in cm  

  8. //3. petal length in cm  

  9. //4. petal width in cm  

  10. //5. class:  

  11.   

  12. class Iris extends java.io.Serializable {  

  13.   var sl: Double = 0  

  14.   var sw: Double = 0  

  15.   var pl: Double = 0  

  16.   var pw: Double = 0  

  17.   

  18.   var sl_max: Double = 0  

  19.   var sw_max: Double = 0  

  20.   var pl_max: Double = 0  

  21.   var pw_max: Double = 0  

  22.   

  23.   var sl_min: Double = 0  

  24.   var sw_min: Double = 0  

  25.   var pl_min: Double = 0  

  26.   var pw_min: Double = 0  

  27.   

  28.   var name: String =""  

  29.   var count: Int = 1  

  30.   

  31.   override def toString : String = {  

  32.     var s =  "(" + name + ", "  

  33.     s = s + "[ " + "%.1f".format(sl) + ", " + "%.1f".format(sl_min) + ", " + "%.1f".format(sl_max) + ", " + "%.1f".format(sl/count) + " ], "  

  34.     s = s + "[ " + "%.1f".format(sw) + ", " + "%.1f".format(sw_min) + ", " + "%.1f".format(sw_max) + ", " + "%.1f".format(sw/count) + " ], "  

  35.     s = s + "[ " + "%.1f".format(pl) + ", " + "%.1f".format(pl_min) + ", " + "%.1f".format(pl_max) + ", " + "%.1f".format(pl/count) + " ], "  

  36.     s = s + "[ " + "%.1f".format(pw) + ", " + "%.1f".format(pw_min) + ", " + "%.1f".format(pw_max) + ", " + "%.1f".format(pw/count) + " ], "  

  37.     s = s + count.toString + ")"  

  38.   

  39.     return s  

  40.   }  

  41. }  

IrisStat.scala



[plain] view plain copy

  1. package basic.iris  

  2.   

  3. import org.apache.spark.{SparkConf, SparkContext}  

  4.   

  5. /**  

  6.   * Created by Oliver on 2017/5/18.  

  7.   */  

  8. //local  

  9. //E:/MyProject/SparkDiscover/data/iris.data  

  10. object IrisStat {  

  11.   

  12.   def isValid(line: String): Boolean = {  

  13.     val parts = line.split(",")  

  14.     return parts.length == 5  

  15.   }  

  16.   

  17.   def parseLine(line: String): (String, Iris) = {  

  18.     val parts = line.split(",")  

  19.     val iris = new Iris  

  20.     iris.sl = parts(0).toDouble  

  21.     iris.sw = parts(1).toDouble  

  22.     iris.pl = parts(2).toDouble  

  23.     iris.pw = parts(3).toDouble  

  24.   

  25.     iris.sl_min = parts(0).toDouble  

  26.     iris.sw_min = parts(1).toDouble  

  27.     iris.pl_min = parts(2).toDouble  

  28.     iris.pw_min = parts(3).toDouble  

  29.   

  30.     iris.sl_max = parts(0).toDouble  

  31.     iris.sw_max = parts(1).toDouble  

  32.     iris.pl_max = parts(2).toDouble  

  33.     iris.pw_max = parts(3).toDouble  

  34.     iris.name = parts(4)  

  35.   

  36.     return (iris.name, iris)  

  37.   }  

  38.   

  39.   def add(a: Iris, b: Iris): Iris = {  

  40.     val c = new Iris  

  41.     c.sl = a.sl + b.sl  

  42.     c.sw = a.sw + b.sw  

  43.     c.pl = a.pl + b.pl  

  44.     c.pw = a.pw + b.pw  

  45.     c.count = a.count + b.count  

  46.     c.name = a.name  

  47.   

  48.     //比较大小  

  49.     c.sl_max = math.max(a.sl_max, b.sl_max)  

  50.     c.sw_max = math.max(a.sw_max, b.sw_max)  

  51.     c.pl_max = math.max(a.pl_max, b.pl_max)  

  52.     c.pw_max = math.max(a.pw_max, b.pw_max)  

  53.   

  54.     c.sl_min = math.min(a.sl_min, b.sl_min)  

  55.     c.sw_min = math.min(a.sw_min, b.sw_min)  

  56.     c.pl_min = math.min(a.pl_min, b.pl_min)  

  57.     c.pw_min = math.min(a.pw_min, b.pw_min)  

  58.   

  59.     return c  

  60.   }  

  61.   

  62.   def printResult(res: (String, Iris)){  

  63.     println(res._2)  

  64.   }  

  65.   

  66.   def main(args: Array[String]){  

  67.     val conf = new SparkConf().setMaster(args(0)).setAppName("Iris")  

  68.     val sc = new SparkContext(conf)  

  69.     val data = sc.textFile(args(1)).filter(isValid(_))  

  70.   

  71.     // distinct  

  72.     println("---1---------------------------")  

  73.     data.map(_.split(",")(4)).distinct().foreach(println)  

  74.   

  75.     // 简单计数  

  76.     val c = data.count()  

  77.     val c_setosa = data.filter( "Iris-setosa" == _.split(",")(4) ).count()  

  78.     val c_versicolor = data.filter( "Iris-versicolor" == _.split(",")(4) ).count()  

  79.     val c_virginica = data.filter( "Iris-virginica" == _.split(",")(4) ).count()  

  80.   

  81.     println("")  

  82.     println("---2---------------------------")  

  83.     println(c, c_setosa, c_versicolor, c_virginica)  

  84.   

  85.   

  86.     // mapreduce 分组求和、求平均、求最大最小  

  87.     //data.map(parseLine(_)).foreach(println)  

  88.     println("")  

  89.     println("---3---------------------------")  

  90.     data.map(parseLine(_)).reduceByKey(add(_,_)).collect().foreach(printResult)  

  91.   

  92.   }  

  93. }  


3.运行配置

0?wx_fmt=png

4. 运行结果


[plain] view plain copy

  1. ---1---------------------------  

  2. Iris-setosa  

  3. Iris-versicolor  

  4. Iris-virginica  

  5.   

  6. ---2---------------------------  

  7. (150,50,50,50)  

  8.   

  9. ---3---------------------------  

  10. (Iris-setosa, [ 250.3, 4.3, 5.8, 5.0 ], [ 170.9, 2.3, 4.4, 3.4 ], [ 73.2, 1.0, 1.9, 1.5 ], [ 12.2, 0.1, 0.6, 0.2 ], 50)  

  11. (Iris-versicolor, [ 296.8, 4.9, 7.0, 5.9 ], [ 138.5, 2.0, 3.4, 2.8 ], [ 213.0, 3.0, 5.1, 4.3 ], [ 66.3, 1.0, 1.8, 1.3 ], 50)  

  12. (Iris-virginica, [ 329.4, 4.9, 7.9, 6.6 ], [ 148.7, 2.2, 3.8, 3.0 ], [ 277.6, 4.5, 6.9, 5.6 ], [ 101.3, 1.4, 2.5, 2.0 ], 50)  

5.代码说明


1)Iris.scala

Iris类存放数据文件的5个字段,当count为1是单个对象的数据,即数据文件的一行。当count大于1时,表示对多行数据的对应字段进行求和运算后的结果,count行计数,_min和_max字段存放的是对应字段的最小值和最大值。重载了toString输出结果。

2)IrisStat.scala 

isValid用于判断数据是否是有效数据,无效则抛弃,完成一个RDD转换,相当于数据清洗。

parseLine解析数据,将字符串转换为对象,方便后续处理。

add执行运算,包括求和、计数、最大值和最小值,平均值在输出是用“和”除以“计数”就得到了。

main函数说明:

(1) data = sc.textFile(args(1)).filter(isValid(_)) 过滤无效数据,如果不过滤,只要数据集中有一行错误数据,程序就会出错,如混入一个空行。

(2) distinct ,map是用split得到第5个字段,然后对第5个字段应用distince就可以了。

(3) filter,输入过滤条件就可以了,然后在调用count计数。

(4)分组求和、最小值和最大值,调研map传人parseLine函数转换得到一个新数据集,这个数据集是(String, Iris)的<key, value>形式,再调用reduceByKey就实现了分组操作;在reduceByKey中传入add执行所定义的运行,调用collect返回数据集,用foreach打印结果。 

(5)注意下划线,理解为要传入的数据就好了,如filter(isValid(_))中,filter把上一个map的数据通过_传给isValid。

0?wx_fmt=gif


没有更多推荐了,返回首页