这个例子来源于scala圣经级教程《Functional Programming in Scala》,虽然原书的随书代码可以找到这些类的影子,但却没有关于如何使用这些类的示例代码,本人在阅读此书期间,除了跟着书中的代码敲了一遍之外,还写了一些测试代码进行验证,贴出来作为blog主要是为了方便自己,也为那些同样在阅读此书的人参考, 因为或许有可能你看了书本也不知道如何使用这些现成的函数和库。注释不多,本人也不可能写太多,因为这本书不是简单的入门书,而是一本进阶,升华用的内涵书籍,三言两语也解释不清其中的很多细节,如果看不懂,但是又感兴趣的就去看原书吧……
Scala推崇使用纯函数 进行编程(纯函数满足引用透明性,因此也更加模块化,更加可测试,这里给一个更抽象的引用透明性的定义: 如果说在程序p中每个出现的表达式e都被其求值的结果给替换,且对p不造成任何影响,我们可以说e对于p而言都是引用透明的))。但并没有说纯函数式编程是不允许使用可变状态的。本章的例子展示了如何在Scala中使用局部可变状态,但又保持一个纯函数式的外壳。示例同时也展示了一些借助Scala类型系统来保证其纯粹性的技巧。
package localeffects
object Mutable {
//java风格的或者说命令式程序风格的快速排序
def quicksort(xs: List[Int]): List[Int] = if (xs.isEmpty) xs else {
val arr = xs.toArray
def swap(x: Int, y: Int) = {
val tmp = arr(x)
arr(x) = arr(y)
arr(y) = tmp
}
def partition(l: Int, r: Int, pivot: Int) = {
val pivotVal = arr(pivot)
swap(pivot, r)
var j = l
for (i <- l until r) if (arr(i) < pivotVal) {
swap(i, j)
j += 1
}
swap(j, r)
j
}
def qs(l: Int, r: Int): Unit = if (l < r) {
val pi = partition(l, r, l + (r - l) / 2)
qs(l, pi - 1)
qs(pi + 1, r)
}
qs(0, arr.length - 1)
arr.toList
}
}
sealed trait ST[S,A] { self =>
protected def run(s: S): (A,S)
def map[B](f: A => B): ST[S,B] = new ST[S,B] {
def run(s: S) = {
val (a, s1) = self.run(s)
(f(a), s1)
}
}
def flatMap[B](f: A => ST[S,B]): ST[S,B] = new ST[S,B] {
def run(s: S) = {
val (a, s1) = self.run(s)
f(a).run(s1)
}
}
}
object ST {
def apply[S,A](a: => A) = {
lazy val memo = a
new ST[S,A] {
def run(s: S) = (memo, s)
}
}
def runST[A](st: RunnableST[A]): A =
st[Unit].run(())._1
}
sealed trait STRef[S,A] {
protected var cell: A
def read: ST[S,A] = ST(cell)
def write(a: => A): ST[S,Unit] = new ST[S,Unit] {
def run(s: S) = {
cell = a
((), s)
}
}
}
object STRef {
def apply[S,A](a: A): ST[S, STRef[S,A]] = ST(new STRef[S,A] {
var cell = a
})
}
trait RunnableST[A] {
def apply[S]: ST[S,A]
}
// Scala requires an implicit Manifest for constructing arrays.
sealed abstract class STArray[S,A](implicit manifest: Manifest[A]) {
protected def value: Array[A]
def size: ST[S,Int] = ST(value.size)
// Write a value at the give index of the array
def write(i: Int, a: A): ST[S,Unit] = new ST[S,Unit] {
def run(s: S) = {
value(i) = a
((), s)
}
}
// Read the value at the given index of the array
def read(i: Int): ST[S,A] = ST(value(i))
// Turn the array into an immutable list
def freeze: ST[S,List[A]] = ST(value.toList)
def fill(xs: Map[Int,A]): ST[S,Unit] =
xs.foldRight(ST[S,Unit](())) {
case ((k, v), st) => st flatMap (_ => write(k, v))
}
def swap(i: Int, j: Int): ST[S,Unit] = for {
x <- read(i)
y <- read(j)
_ <- write(i, y)
_ <- write(j, x)
} yield ()
}
object STArray {
// Construct an array of the given size filled with the value v
def apply[S,A:Manifest](sz: Int, v: A): ST[S, STArray[S,A]] =
ST(new STArray[S,A] {
lazy val value = Array.fill(sz)(v)
})
def fromList[S,A:Manifest](xs: List[A]): ST[S, STArray[S,A]] =
ST(new STArray[S,A] {
lazy val value = xs.toArray
})
}
object Immutable {
def noop[S] = ST[S,Unit](())
def partition[S](a: STArray[S,Int], l: Int, r: Int, pivot: Int): ST[S,Int] = for {
vp <- a.read(pivot)
_ <- a.swap(pivot, r)
j <- STRef(l)
_ <- (l until r).foldLeft(noop[S])((s, i) => for {
_ <- s
vi <- a.read(i)
_ <- if (vi < vp) (for {
vj <- j.read
_ <- a.swap(i, vj)
_ <- j.write(vj + 1)
} yield ()) else noop[S]
} yield ())
x <- j.read
_ <- a.swap(x, r)
} yield x
def qs[S](a: STArray[S,Int], l: Int, r: Int): ST[S, Unit] = if (l < r) for {
pi <- partition(a, l, r, l + (r - l) / 2)
_ <- qs(a, l, pi - 1)
_ <- qs(a, pi + 1, r)
} yield () else noop[S]
//scala风格的或者说纯函数式风格的快速排序
def quicksort(xs: List[Int]): List[Int] =
if (xs.isEmpty) xs else ST.runST(new RunnableST[List[Int]] {
def apply[S] = for {
arr <- STArray.fromList(xs)
size <- arr.size
_ <- qs(arr, 0, size - 1)
sorted <- arr.freeze
} yield sorted
})
}
import scala.collection.mutable.HashMap
sealed trait STMap[S,K,V] {
protected def table: HashMap[K,V]
def size: ST[S,Int] = ST(table.size)
// Get the value under a key
def apply(k: K): ST[S,V] = ST(table(k))
// Get the value under a key, or None if the key does not exist
def get(k: K): ST[S, Option[V]] = ST(table.get(k))
// Add a value under a key
def +=(kv: (K, V)): ST[S,Unit] = ST(table += kv)
// Remove a key
def -=(k: K): ST[S,Unit] = ST(table -= k)
}
object STMap {
def empty[S,K,V]: ST[S, STMap[S,K,V]] = ST(new STMap[S,K,V] {
val table = HashMap.empty[K,V]
})
def fromMap[S,K,V](m: Map[K,V]): ST[S, STMap[S,K,V]] = ST(new STMap[S,K,V] {
val table = (HashMap.newBuilder[K,V] ++= m).result
})
}
object Test {
def main(args: Array[String]): Unit = {
import testing.Gen
import state.RNG.Simple
val rng = Simple(System.currentTimeMillis())
val chooseGen = Gen.choose(1, 101)
val size = 40
//生成长度为size、元素大小在闭区间[1, 100]的链表
val li = Gen.listOf(chooseGen).g(size).sample.run(rng)._1
println(li)
val sorted = Mutable.quicksort(li)//java风格的或者说命令式程序风格的快速排序
println(sorted)
val res = ST.runST(new RunnableST[Int] {
def apply[Unit]: ST[Unit, Int] = ST(10)
})
println(res)
println(Immutable.quicksort(li))//scala风格的或者说纯函数式风格的快速排序
}
}
上述代码的运行结果是:
List(36, 62, 62, 12, 43, 42, 64, 96, 48, 9)
List(9, 12, 36, 42, 43, 48, 62, 62, 64, 96)
10
List(9, 12, 36, 42, 43, 48, 62, 62, 64, 96)