Scala's pattern matching can match on values, types, collections, Options, case classes, functions, and more.
Let's start by examining a pattern-matching example from the Spark source code:
val classNameV1_0 = "org.apache.spark.mllib.classification.LogisticRegressionModel"
(loadedClassName, version) match {
  case (className, "1.0") if className == classNameV1_0 =>
    val (numFeatures, numClasses) = ClassificationModel.getNumFeaturesClasses(metadata)
    val data = GLMClassificationModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0)
    // numFeatures, numClasses, weights are checked in model initialization
    val model =
      new LogisticRegressionModel(data.weights, data.intercept, numFeatures, numClasses)
    data.threshold match {
      case Some(t) => model.setThreshold(t)
      case None => model.clearThreshold()
    }
    model
  case _ => throw new Exception(
    s"LogisticRegressionModel.load did not recognize model with (className, format version):" +
    s"($loadedClassName, $version). Supported:\n" +
    s"  ($classNameV1_0, 1.0)")
}
This example matches on a tuple, (loadedClassName, version), and one of its case arms contains a nested match on data.threshold.
Each case arm behaves like a function, written with the => symbol: the pattern on the left is the input, and the expression on the right is the body.
The sections below cover each kind of pattern matching in detail.
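One point the Spark snippet above relies on is that match is an expression: each case arm evaluates to a value, and the whole match yields the value of the arm that fires (that is how the snippet returns model). A minimal sketch with hypothetical names:

```scala
// A match is an expression: it evaluates to the value of the matching arm.
// `describe` and the version strings are made up, just to show the shape.
def describe(version: String): String = version match {
  case "1.0" => "legacy format"            // this arm's value is returned
  case "2.0" => "current format"
  case other => s"unknown version: $other" // `other` binds the unmatched value
}

println(describe("1.0"))
println(describe("3.0"))
```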
Matching on values:
scala> def getSalary(name: String, age: Int) {
| name match {
| case "Spark" => println(name + ":$150000/year")
| case "Hadoop" => println(name + ":$100000/year")
| case _ if name == "Scala" => println(name + ":$140000/year")
| case _ if name == "C" => println(name + ":$90000/year")
| case bbb if age >= 5 => println("name: " + bbb + ", age: " + age + ",$120000/year")
| case _ => println(name + ":$80000/year")
| }
| }
getSalary: (name: String, age: Int)Unit
scala> getSalary("Java",10)
name: Java, age: 10,$120000/year
scala> getSalary("Java",2)
Java:$80000/year
scala> getSalary("Spark",2)
Spark:$150000/year
scala> getSalary("Hadoop",2)
Hadoop:$100000/year
scala> getSalary("Scala",2)
Scala:$140000/year
scala> getSalary("C",2)
C:$90000/year
scala> getSalary("R",2)
R:$80000/year
Explanation:
In the arm case bbb if age >= 5, the pattern variable bbb is bound to the value of the name argument.
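Besides binding through a bare lowercase identifier such as bbb, Scala also offers the @ binder, which captures the matched value while still constraining its shape. A short sketch (the function and values are my own illustration):

```scala
// `case name @ pattern` binds the whole matched value to `name`
// while still requiring it to fit `pattern`.
def classify(x: Any): String = x match {
  case s @ ("Spark" | "Hadoop") => s"$s is a big-data framework"
  case n @ (_: Int) if n > 100  => s"$n is a big number" // n is typed Int here
  case other                    => s"$other is something else"
}

println(classify("Spark"))
println(classify(150))
```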
Matching on types:
scala> def getMatchType(msg: Any) {
| msg match {
| case i: Int => println("Integer")
| case s: String => println("String")
| case d: Double => println("Double")
| case array: Array[Int] => println("Array")
     | case _ => println("Unknown type")
     | }
     | }
getMatchType: (msg: Any)Unit
scala> getMatchType(10)
Integer
scala> getMatchType(10.000)
Double
scala> getMatchType("Good!")
String
scala> getMatchType(Array(1,2,3))
Array
scala> getMatchType(true)
Unknown type
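One caveat worth knowing, not shown above: because of JVM type erasure, the element type of a generic collection cannot be checked at runtime, so a pattern like List[Int] would match any List (and the compiler warns about it). Arrays are the exception, since the JVM preserves their element type. A sketch:

```scala
// JVM erasure: a List's element type is gone at runtime, so match on
// the shape with a wildcard; arrays keep their element type and are checkable.
def check(msg: Any): String = msg match {
  case _: Array[Int] => "Array[Int]"  // element type is checked for arrays
  case _: List[_]    => "some List"   // wildcard: element type is unknowable
  case _             => "other"
}

println(check(Array(1, 2, 3)))
println(check(List("a", "b")))
```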
Matching on collections:
scala> def getMatchCollection(msg: Array[String]) {
| msg match {
| case Array("Scala") => println("One element")
| case Array("Scala", "Java") => println("Two elements")
     | case Array("Spark", _*) => println("Many elements beginning with Spark")
     | case _ => println("Unknown type")
     | }
     | }
getMatchCollection: (msg: Array[String])Unit
scala> getMatchCollection(Array("Scala", "Java"))
Two elements
scala> getMatchCollection(Array("Scala", "Java","aa"))
Unknown type
scala> getMatchCollection(Array("Spark","Scala", "Java","aa"))
Many elements beginning with Spark
scala> getMatchCollection(Array("Spark"))
Many elements beginning with Spark
scala> getMatchCollection(Array("Scala"))
One element
scala> getMatchCollection(Array("Java"))
Unknown type
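The same sequence patterns work for other collections, and List additionally supports the :: (cons) pattern, which splits a list into head and tail. A small sketch (function name is my own):

```scala
// `head :: tail` deconstructs a non-empty List; Nil matches the empty one.
def describeList(xs: List[String]): String = xs match {
  case Nil             => "empty"
  case head :: Nil     => s"one element: $head"
  case "Spark" :: rest => s"starts with Spark, ${rest.size} more"
  case _               => "something else"
}

println(describeList(List("Spark", "Scala", "Java")))
println(describeList(List("Scala")))
```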
Matching on case classes:
scala> class DataFrameWork
defined class DataFrameWork
scala> case class ComputationFramework(name: String, popular: Boolean) extends DataFrameWork
defined class ComputationFramework
scala> case class StorageFramework(name: String, popular: Boolean) extends DataFrameWork
defined class StorageFramework
scala> def getBigDataType(data: DataFrameWork) {
| data match {
| case ComputationFramework(name, popular) =>
| println("ComputationFramework :" + "name : " + name + " , popular :" + popular)
| case StorageFramework(name, popular) =>
| println("StorageFramework :" + "name : " + name + " , popular :" + popular)
| case _ => println("Some other type")
| }
| }
getBigDataType: (data: DataFrameWork)Unit
scala> getBigDataType(ComputationFramework("Spark", true))
ComputationFramework :name : Spark , popular :true
scala> getBigDataType(ComputationFramework("Spark", false))
ComputationFramework :name : Spark , popular :false
scala> getBigDataType(StorageFramework("HDFS", true))
StorageFramework :name : HDFS , popular :true
scala> getBigDataType(null)
Some other type
scala> getBigDataType(new StorageFramework("HDFS", true))
StorageFramework :name : HDFS , popular :true
Note that in ComputationFramework("Spark", true) an instance is created by writing the class name directly, without the new keyword:
for every case class the Scala compiler automatically generates a companion object with an apply factory method, so ComputationFramework("Spark", true) desugars to a call to that apply.
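The generated companion object also provides an unapply extractor, which is what makes the case ComputationFramework(name, popular) pattern work; you can write the same pair by hand for a plain class. A sketch of roughly what the compiler generates, using a hypothetical Framework class:

```scala
// Roughly what the compiler generates for a case class, written by hand:
class Framework(val name: String, val popular: Boolean)

object Framework {
  // factory: Framework("Spark", true) instead of new Framework(...)
  def apply(name: String, popular: Boolean): Framework =
    new Framework(name, popular)
  // extractor: enables `case Framework(n, p) =>` patterns
  def unapply(f: Framework): Option[(String, Boolean)] =
    Some((f.name, f.popular))
}

val msg = Framework("Spark", true) match {
  case Framework(n, p) => s"$n, popular = $p"
}
println(msg)
```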
Matching on Option:
If a value is present, an Option[T] is a Some[T]; if it is absent, the Option[T] is the object None. Matching on the Option is the recommended way to extract the value.
scala> def getValue(key: String, content: Map[String, String]) {
| content.get(key) match {
| case Some(value) => println(value)
| case None => println("Not Found!!!")
| }
| }
getValue: (key: String, content: Map[String,String])Unit
scala> getValue("Spark",Map("Spark"->"100","Hadoop"->"90"))
100
scala> getValue("scala",Map("Spark"->"100","Hadoop"->"90"))
Not Found!!!
scala> getValue("Hadoop",Map("Spark"->"100","Hadoop"->"90"))
90
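Matching on Option is explicit, but for simple defaulting the standard library's getOrElse and fold are more concise equivalents of the Some/None match above. A sketch:

```scala
val scores = Map("Spark" -> "100", "Hadoop" -> "90")

// Equivalent to the Some/None match above, in a single call each:
println(scores.getOrElse("Spark", "Not Found!!!"))           // present key
println(scores.get("scala").fold("Not Found!!!")(identity))  // absent key
```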
Disclaimer:
This post is based on the preliminary course of DT Big Data Dream Factory's (DT大数据梦工厂) Spark "Mushroom Cloud" program. Course videos are available at: http://pan.baidu.com/s/1cFqjQu