这个例子来源于 Scala 圣经级教程《Functional Programming in Scala》。我跟着书中的代码敲了一遍,又写了些测试代码验证正确性,所以放在这里做个备忘。贴出来主要是方便自己查阅;如果看不懂又感兴趣,建议直接阅读原书。
package state
/**
 * A purely functional random number generator: instead of mutating hidden state,
 * every draw returns the generated value together with the generator to use next.
 */
trait RNG {
  /** Returns an `Int` together with the successor generator. */
  def nextInt: (Int, RNG)
}

object RNG {

  /**
   * Linear congruential generator.
   *
   * Each step computes `newSeed = (seed * 0x5DEECE66D + 0xB)` truncated to its low 48 bits
   * and returns bits 16..47 of `newSeed` as the 32-bit result.
   *
   * @param seed current generator state
   */
  case class Simple(seed: Long) extends RNG {
    override def nextInt: (Int, RNG) = {
      // Advance the state; the mask keeps only the low 48 bits.
      val newSeed = (seed * 0x5DEECE66DL + 0xBL) & 0xFFFFFFFFFFFFL
      val nextRNG = Simple(newSeed)
      // `>>>` is an unsigned shift, so bits 16..47 become the Int result.
      val n = (newSeed >>> 16).toInt
      (n, nextRNG)
    }
  }

  /**
   * Generates an Int in [0, Int.MaxValue].
   *
   * Negative values are mapped with `-(i + 1)` rather than `-i`, so that Int.MinValue
   * becomes Int.MaxValue instead of overflowing back to a negative number.
   */
  def nonNegativeInt(rng: RNG): (Int, RNG) = {
    val (i, r) = rng.nextInt
    (if (i < 0) -(i + 1) else i, r)
  }

  /**
   * Generates a Double in [0, 1).
   *
   * We generate an integer >= 0 and divide it by one higher than the maximum, so the
   * result can never reach 1.0. This is just one possible solution.
   *
   * @param rng generator to draw from
   * @return a Double in [0, 1) and the successor generator
   */
  def double(rng: RNG): (Double, RNG) = {
    val (i, r) = nonNegativeInt(rng)
    // The parentheses are essential: `i / Int.MaxValue.toDouble + 1` parses as
    // `(i / Int.MaxValue.toDouble) + 1` and would produce values in [1, 2].
    (i / (Int.MaxValue.toDouble + 1), r)
  }

  /** Generates a Boolean that is `true` when the next Int is even. */
  def boolean(rng: RNG): (Boolean, RNG) = {
    rng.nextInt match {
      case (i, rng2) => (i % 2 == 0, rng2)
    }
  }

  /** Generates an (Int, Double) pair; the Int is drawn first and its generator feeds `double`. */
  def intDouble(rng: RNG): ((Int, Double), RNG) = {
    val (i, r1) = rng.nextInt
    val (d, r2) = double(r1)
    ((i, d), r2)
  }

  /** Generates a (Double, Int) pair by swapping the components produced by `intDouble`. */
  def doubleInt(rng: RNG): ((Double, Int), RNG) = {
    val ((i, d), r) = intDouble(rng)
    ((d, i), r)
  }

  /** Generates three Doubles in [0, 1), each drawn from the generator returned by the previous draw. */
  def double3(rng: RNG): ((Double, Double, Double), RNG) = {
    val (d1, r1) = double(rng)
    val (d2, r2) = double(r1)
    val (d3, r3) = double(r2)
    ((d1, d2, d3), r3)
  }

  /**
   * A simple recursive solution: generates `count` Ints in draw order.
   *
   * Not tail-recursive, so the stack depth grows with `count`. A non-positive `count`
   * yields an empty list and returns `rng` unchanged.
   */
  def ints(count: Int)(rng: RNG): (List[Int], RNG) =
    if (count <= 0) (List(), rng)
    else {
      val (x, r1) = rng.nextInt
      val (xs, r2) = ints(count - 1)(r1)
      (x :: xs, r2)
    }

  /**
   * A tail-recursive solution: generates `count` Ints.
   *
   * Each value is prepended, so the list is in reverse draw order compared with `ints`
   * (same values, same final generator). A non-positive `count` yields an empty list and
   * returns `rng` unchanged.
   */
  def ints2(count: Int)(rng: RNG): (List[Int], RNG) = {
    @scala.annotation.tailrec
    def go(count: Int, r: RNG, xs: List[Int]): (List[Int], RNG) =
      if (count <= 0)
        (xs, r)
      else {
        val (x, r2) = r.nextInt
        go(count - 1, r2, x :: xs)
      }
    go(count, rng, List())
  }

  /** A state action: given a generator, produce a value and the next generator. */
  type Rand[+A] = RNG => (A, RNG)

  /**
   * `int` is shorthand for: val int: Rand[Int] = (rng: RNG) => rng.nextInt
   */
  val int: Rand[Int] = _.nextInt

  /** An action that returns `a` and passes the generator through unchanged. */
  def unit[A](a: A): Rand[A] = rng => (a, rng)

  /** Transforms the value produced by `s` with `f`; the generator is threaded through as-is. */
  def map[A, B](s: Rand[A])(f: A => B): Rand[B] = {
    rng => {
      val (a, rng2) = s(rng)
      (f(a), rng2)
    }
  }

  /** `double` expressed with `map`: a Double in [0, 1). */
  val _double: Rand[Double] = map(nonNegativeInt)(_ / (Int.MaxValue.toDouble + 1))

  /** Runs `ra`, then runs `rb` on the generator `ra` returned, and combines both values with `f`. */
  def map2[A, B, C](ra: Rand[A], rb: Rand[B])(f: (A, B) => C): Rand[C] = {
    rng => {
      val (a, rng1) = ra(rng)
      val (b, rng2) = rb(rng1)
      (f(a, b), rng2)
    }
  }

  /** Pairs the results of `ra` and `rb`, running `ra` first. */
  def both[A, B](ra: Rand[A], rb: Rand[B]): Rand[(A, B)] = map2(ra, rb)((_, _))

  val randIntDouble: Rand[(Int, Double)] = both(int, double)

  val randDoubleInt: Rand[(Double, Int)] = both(double, int)

  /**
   * Combines a list of actions into one action that yields their results in list order.
   * The head action runs first and each subsequent action receives the previous generator.
   */
  def sequence[A](fs: List[Rand[A]]): Rand[List[A]] =
    fs.foldRight(unit(List[A]()))((f, acc) => map2(f, acc)(_ :: _))

  /** `ints` expressed with `sequence`: `n` Ints in draw order (empty when `n <= 0`). */
  def _ints(n: Int): Rand[List[Int]] = sequence(List.fill(n)(int))

  /** Runs `f`, then feeds its value to `g` and runs the resulting action on the new generator. */
  def flatMap[A, B](f: Rand[A])(g: A => Rand[B]): Rand[B] = rng => {
    val (a, r1) = f(rng)
    g(a)(r1) // pass the new state along
  }

  /**
   * Generates an Int in [0, n), avoiding the bias a plain `i % n` would introduce.
   *
   * `i - mod` is the start of the block of `n` consecutive values that contains `i`.
   * If `i + (n - 1) - mod` wraps negative, that block extends past Int.MaxValue and is
   * incomplete, so using it would favour small results; in that case we retry with the
   * next generator.
   *
   * NOTE(review): `n` must be positive. With `n == 0`, `i % n` throws ArithmeticException
   * when the action runs.
   */
  def nonNegativeLessThan(n: Int): Rand[Int] = {
    flatMap(nonNegativeInt) { i =>
      val mod = i % n
      if (i + (n - 1) - mod >= 0) unit(mod) else nonNegativeLessThan(n)
    }
  }

  /** `map` expressed with `flatMap` and `unit`. */
  def _map[A, B](s: Rand[A])(f: A => B): Rand[B] = flatMap(s)(a => unit(f(a)))

  /** `map2` expressed with `flatMap` and `map`. */
  def _map2[A, B, C](ra: Rand[A], rb: Rand[B])(f: (A, B) => C): Rand[C] =
    flatMap(ra)(a => map(rb)(b => f(a, b)))

  /** Ad-hoc smoke test: prints the output of each generator for a seed taken from the clock. */
  def main(args: Array[String]): Unit = {
    val now = System.currentTimeMillis()
    val simple = Simple(now)
    println(simple)
    println(simple.nextInt)
    val nonNegativeIntVal = RNG.nonNegativeInt(simple)
    println(nonNegativeIntVal)
    val doubleVal = RNG.double(simple)
    println(doubleVal)
    val booleanVal = RNG.boolean(simple)
    println(booleanVal)
    val intDoubleVal = RNG.intDouble(simple)
    println(intDoubleVal)
    val doubleIntVal = RNG.doubleInt(simple)
    println(doubleIntVal)
    val double3Val = RNG.double3(simple)
    println(double3Val)
    val intsVal = RNG.ints(3)(simple)
    println(intsVal)
    val ints2Val = RNG.ints2(3)(simple)
    println(ints2Val)
    println(RNG.int(simple))
    println(simple)
    println(RNG.unit(100)(simple)._1)
    println(RNG.unit(100)(simple)._2)
    println(RNG.unit(100)(simple)._2.nextInt)
    println(RNG.map(RNG.int)(_ * 2)(simple))
    val _doubleVal = RNG._double(simple)
    println(_doubleVal)
    val map2Val = RNG.map2(RNG.int, RNG.int)((_, _))(simple)
    println(map2Val)
    val bothVal = RNG.both(RNG.int, RNG.int)(simple)
    println(bothVal)
    val intRand5List = List.fill(5)(RNG.int)
    val sequenceVal = RNG.sequence(intRand5List)(simple)
    println(sequenceVal)
    val _intsVal = RNG._ints(5)(simple)
    println(_intsVal)
    val nonNegativeLessThanVal = RNG.nonNegativeLessThan(Int.MaxValue)(simple)
    println(nonNegativeLessThanVal)
  }
}
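注意:以下输出中 double、intDouble、doubleInt、double3 给出的 Double 值都大于 1。原因是表达式 `i / Int.MaxValue.toDouble + 1` 按运算符优先级被解析成了 `(i / Int.MaxValue.toDouble) + 1`;正确的写法是 `i / (Int.MaxValue.toDouble + 1)`(与 `_double` 一致),结果才会落在 [0, 1) 区间内。
上述代码某次运行的结果如下(种子取自 System.currentTimeMillis(),所以每次运行的具体数值都不同):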
Simple(1530788766246)
(-697541284,Simple(235760911150649))
(697541283,Simple(235760911150649))
(1.3248179719433273,Simple(235760911150649))
(true,Simple(235760911150649))
((-697541284,1.061915150872392),Simple(8713782830160))
((1.061915150872392,-697541284),Simple(8713782830160))
((1.3248179719433273,1.061915150872392,1.8075593983789717),Simple(167821095294491))
(List(-697541284, 132961774, -1734220603),Simple(167821095294491))
(List(-1734220603, 132961774, -697541284),Simple(167821095294491))
(-697541284,Simple(235760911150649))
Simple(1530788766246)
100
Simple(1530788766246)
(-697541284,Simple(235760911150649))
(-1395082568,Simple(235760911150649))
(0.32481797179207206,Simple(235760911150649))
((-697541284,132961774),Simple(8713782830160))
((-697541284,132961774),Simple(8713782830160))
(List(-697541284, 132961774, -1734220603, 630554181, -192781157),Simple(268840870823373))
(List(-697541284, 132961774, -1734220603, 630554181, -192781157),Simple(268840870823373))
(697541283,Simple(235760911150649))