用Scala的高阶函数让Chisel代码更优雅
动机
前面模块中讨厌的for
循环很冗长,破坏了函数式编程的目的,这一篇将会介绍高阶函数,让我们写生成器的时候更快。
FIR的卷积操作居然一行代码就能实现?
前面我们写了FIR过滤器的实现,其中的卷积部分是这么写的:
val muls = Wire(Vec(length, UInt(8.W)))
for(i <- 0 until length) {
if(i == 0) muls(i) := io.in * io.consts(i)
else muls(i) := regs(i - 1) * io.consts(i)
}
val scan = Wire(Vec(length, UInt(8.W)))
for(i <- 0 until length) {
if(i == 0) scan(i) := muls(i)
else scan(i) := muls(i) + scan(i - 1)
}
io.out := scan(length - 1)
简单说一下它的基本思想,首先把io.in
的每个元素与对应的io.const
相乘,然后存放到muls
中,然后muls
中的元素累加到scan
,其中scan(0) = muls(0)
,scan(1) = scan(0) + muls(1) = muls(0) + muls(1)
,一般化就是scan(n) = scan(n-1) + muls(n) = muls(0) + ... + muls(n-1) + muls(n)
,最后scan
的最后一个元素(等于muls
所有元素的值)被赋值到io.out
。
但是啊,真的很繁琐,罗里吧嗦半天就为了实现一个简单的操作。实际上呢,上面的操作可以写成一行代码:
io.out := (taps zip io.consts).map { case (a, b) => a * b }.reduce(_ + _)
它是怎么办到的呢?下面一点点来分析:
- 假设
taps
是所有采样的列表,也就是taps(0) = io.in
,taps(1) = regs(0)
等; (taps zip io.consts)
接受两个列表,taps
和io.const
,然后组合他们为一个列表,其每个元素都是输入的相应位置的元素的元组,具体来说,它的值就像这样:[(taps(0), io.consts(0)), (taps(1), io.consts(1)), ..., (taps(n), io.consts(n))]
。注意,因为.
是可以省略的,所以这个其实等价于(taps.zip(io.const))
;.map {case (a, b) => a*b}
应用了一个匿名函数,这个函数接受一个二元素的元组然后返回他们的乘积,这个匿名函数应用在了列表的元素上并返回结果也是列表,结果为[taps(0) * io.consts(0), taps(1) * io.consts(1), ..., taps(n) * io.consts(n)]
,和之前的muls
是等价的。现在先简单了解写匿名函数的语法,后面还会详细介绍;- 最后,
.reduce(_ + _)
在列表的元素上应用了函数(元素之间相加)。然而,这里同样接受两个参数,第一个是当前的累加结果,第二个是列表元素(第一次迭代中两个参数都是列表元素)。这是通过圆括号里面的两个下划线给定的。假设从左到右遍历,结果会是(((muls(0) + muls(1)) + muls(2)) + ...) + muls(n)
,最先计算的有更深的括号嵌套。这个结果就是卷积的输出。
作为参数的函数
正式地来讲,像map
和reduce
这样的函数就是高阶函数,因为它们是以函数为参数的函数。结果也证明,使用高阶函数能够省不少事,可以封装一个一般的计算模式,允许你在写代码是专注于应用整体的逻辑而不是控制流,能写出非常简洁的代码。
指定函数的不同方法
上面已经提到了两种了,这里总结一下:
- 对于每个元素只会引用一次的函数,可能可以使用下划线
_
来指代每个元素,在上面的例子中,reduce
的参数函数接受两个元素就可以被表示为_ + _
。虽然方便,但这受限于极少数满足一些复杂规则的情况,所以如果不行的话,可以试试下面的方法: - 显式指定输入参数列表。上面的规约操作可以被显式写成
(a, b) => a + b
,用的是把参数列表放在括号里的一般形式,然后跟着=>
符号,再跟着引用这些参数的函数体; - 如果需要解包元组的时候,就用
case
语句,就像case (a, b) => a + b
这里的用法一样。这种情况是接受了单个参数,即一个两元素的元组,然后将其解包为变量a
和b
,然后用于后面的函数体。
Scala中的实践
上上一篇文章中,我们介绍了Scala集合类API中的主要类,比如List
这种。这些高阶函数其实也是这些API的一部分,比如上面的map
和reduce
都是List
上的API。这一小节我们通过一些例子来熟悉这些方法。在例子中,简洁起见我们都在Scala的数字(Int
)上进行操作,但是因为Chisel的操作符也是类似的,所以这些概念可以泛化。
Map
List[A].map
有类型签名map[B](f: (A) ⇒ B): List[B]
。后面会有专门的一篇讲解关于类型的知识,现在就把这个类型A
和B
看成是Int
或者SInt
,意味着它们可以是软件类型也可以是硬件类型。
说人话就是,它接受一个类型为(f: (A) ⇒ B)
的参数,或者是一个接受两个参数的函数,第一个参数类型为A
,即和输入列表的元素类型一致,第二个参数类型就随便了,什么类型都可以,然后map
会返回一个类型B
的列表,即参数函数的返回值类型。
因为我们已经解释过了FIR例子中列表的行为,现在就直接看例子吧:
println(List(1, 2, 3, 4).map(x => x + 1)) // 函数中的显式参数列表
println(List(1, 2, 3, 4).map(_ + 1)) // 和上面的等价,但是隐式的
println(List(1, 2, 3, 4).map(_.toString + "a")) // 输出元素类型可和输入不同
println(List((1, 5), (2, 6), (3, 7), (4, 8)).map { case (x, y) => x*y }) // 用case解包元组,注意这里用的是大括号
// 提一嘴,Scala中有构造连续数字列表的语法
println(0 to 10) // to是inclusive的, 这里的10是包括在内的
println(0 until 10) // until是exclusive的,这里的10不包括在内
// 上面生成的和列表的行为基本一致,生成索引的时候很有用
val myList = List("a", "b", "c", "d")
println((0 until 4).map(myList(_)))
输出如下:
List(2, 3, 4, 5)
List(2, 3, 4, 5)
List(1a, 2a, 3a, 4a)
List(5, 12, 21, 32)
Range 0 to 10
Range 0 until 10
Vector(a, b, c, d)
再来个简单的练习,想要让列表中的每个元素的翻倍,???
处填什么代码呢?
println(List(1, 2, 3, 4).map(???))
显然,填_ * 2
就行了。
zipWithIndex
zipWithIndex
的类型签名是zipWithIndex: List[(A, Int)]
。
也就是说不接受任何参数,但是返回一个列表,其每个元素都是源数据和它的索引(第一个元素索引为0)。所以说,List("a", "b", "c", "d").zipWithIndex
会返回List(("a", 0), ("b", 1), ("c", 2), ("d", 3))
。
这在某些操作中,需要元素索引的场合特别有用。
这个也很简单,直接上例子:
println(List(1, 2, 3, 4).zipWithIndex) // 注意索引从0开始
println(List("a", "b", "c", "d").zipWithIndex)
println(List(("a", "b"), ("c", "d"), ("e", "f"), ("g", "h")).zipWithIndex) // 嵌套元组
输出如下:
List((1,0), (2,1), (3,2), (4,3))
List((a,0), (b,1), (c,2), (d,3))
List(((a,b),0), ((c,d),1), ((e,f),2), ((g,h),3))
Reduce
List[A].reduce
的类型签名和List[A].map
差不多,为reduce (op: (A, A) ⇒ A)
,这里就很宽松了,A
只需要是List
类型的超类就行了,但是这里不讨论这些语法。
直接上例子:
println(List(1, 2, 3, 4).reduce((a, b) => a + b)) // 返回所有元素的和
println(List(1, 2, 3, 4).reduce(_ * _)) // 返回所有元素的积
println(List(1, 2, 3, 4).map(_ + 1).reduce(_ + _)) // 可以把reduce放在map后
输出为:
10
24
14
需要注意的是,在空列表上使用reduce
是不行的:
println(List[Int]().reduce(_ * _))
会报错:
java.lang.UnsupportedOperationException: empty.reduceLeft
现在稍微练习以下,在???
处填入代码,使得列表的元素先翻倍再累乘:
println(List(1, 2, 3, 4).map(???).reduce(???))
很简单,这么写就行:
println(List(1, 2, 3, 4).map(_ * 2).reduce(_ * _))
Fold
List[A].fold
和reduce
类型,除了以不能指定规约运算的初始值。类型签名和reduce
是类似的:fold(z: A)(op: (A, A) ⇒ A): A
。
注意,它有两个参数,第一个参数z
是初始值,第二个参数是规约的函数。和reduce
不同,对于空列表它不会失效,而是会直接返回初始值。
例子来了:
println(List(1, 2, 3, 4).fold(0)(_ + _)) // 等价于用reduce的累加
println(List(1, 2, 3, 4).fold(1)(_ + _)) // 和上面的差不多,但是从1开始累加
println(List().fold(1)(_ + _)) // 和reduce不一样,fold可以在空列表上执行
输出为:
10
11
1
小小练习一下,现在要用fold
返回一个列表的累乘值的两倍,???
处怎么写:
println(List(1, 2, 3, 4).fold(???)(???))
这还用想?
println(List(1, 2, 3, 4).fold(2)(_ * _))
不过需要注意的是,除非需要容忍空列表,不然还是用reduce
更好。
Chisel中的实践——Decoupled Arbiter
现在结合上面所学,实现一个Decoupled
的仲裁器,要求有n
个Decoupled
输入和一个Decoupled
输出。仲裁器选择有效通道中索引最低的转发到输出。
几点提示:
- 如果有任何输入有效的话,
io.out.valid
就为真; - 可以考虑在模块内部给被选择的通道整个
Wire
; - 如果输出
ready
为真的话,且某个通道被选择,则对应的输入的ready
为真,(注意这里把ready
和valid
耦合到一起去了,但是这里先忽略); - 可能会用到
map
,尤其是用来返回子元素的Vec
时,比如io.in.map(_.valid)
就会返回输入Bundle的有效信号的列表; - 可能用到
PriorityMux(List[Bool, Bits])
,接受一个列表的有效信号和数据,返回第一个有效的元素; - 可能用到
Vec
的动态索引,通过一个UInt
数来索引,比如io.in(0.U)
。
一样的,在???
处填上自己的代码:
import chisel3._
import chisel3.util._
import chisel3.tester._
import chisel3.tester.RawTester.test
object MyModule extends App {
class MyRoutingArbiter(numChannels: Int) extends Module {
val io = IO(new Bundle {
val in = Vec(numChannels, Flipped(Decoupled(UInt(8.W))))
val out = Decoupled(UInt(8.W))
} )
// 在这里填上自己的代码
???
}
test(new MyRoutingArbiter(4)) { c =>
// 设置初始值
for(i <- 0 until 4) {
c.io.in(i).valid.poke(false.B)
c.io.in(i).bits.poke(i.U)
c.io.out.ready.poke(true.B)
}
c.io.out.valid.expect(false.B)
// 测试有背压的单输入有效的行为
for (i <- 0 until 4) {
c.io.in(i).valid.poke(true.B)
c.io.out.valid.expect(true.B)
c.io.out.bits.expect(i.U)
c.io.out.ready.poke(false.B)
c.io.in(i).ready.expect(false.B)
c.io.out.ready.poke(true.B)
c.io.in(i).valid.poke(false.B)
}
// 测试有背压的多输入有效的行为
c.io.in(1).valid.poke(true.B)
c.io.in(2).valid.poke(true.B)
c.io.out.bits.expect(1.U)
c.io.in(1).ready.expect(true.B)
c.io.in(0).ready.expect(false.B)
c.io.out.ready.poke(false.B)
c.io.in(1).ready.expect(false.B)
}
println("SUCCESS!!") // Scala Code: if we get here, our tests passed!
}
问号处代码应为:
io.out.valid := io.in.map(_.valid).reduce(_ || _)
val channel = PriorityMux(
io.in.map(_.valid).zipWithIndex.map { case (valid, index) => (valid, index.U) }
)
io.out.bits := io.in(channel).bits
io.in.map(_.ready).zipWithIndex.foreach { case (ready, index) =>
ready := io.out.ready && channel === index.U
}
测试通过,下面解释一下:
- 用
map
取出io.in
的valid
信号的列表,再规约进行或运算,就知道其中是否至少有一个有效; - 构造一个优先级多路选择器,输入是
valid
信号和索引才行,所以对valid
列表进行了一个zipWithIndex
,然后继续应用map
,将索引转为硬件格式,多路选择器的输出就是有效且最低索引的索引; - 输出的
bits
就是channel
索引对应的输入bits
; - 要对输入的每个
ready
信号进行操作,需要用到一个foreach
,它也是List
上的函数,虽然前面没提到但是很好懂,就是对于每个元素进行操作,无需返回值,这里用同样的方法提取除了ready
和索引,然后根据输出的ready
信号和被通道的索引来判断是否应该将这个输入的ready
信号置为有效。
有一点需要提一下,为什么不直接用PriorityMux
输出io.in.bits
呢?因为需要设置io.in.ready
位,所以必须要知道被选择的输入的索引,所以PriorityMux
用来找索引了。