Chisel教程——11.用Scala的高阶函数让Chisel代码更优雅

用Scala的高阶函数让Chisel代码更优雅

动机

前面模块中讨厌的for循环很冗长,破坏了函数式编程的目的,这一篇将会介绍高阶函数,让我们写生成器的时候更快。

FIR的卷积操作居然一行代码就能实现?

前面我们写了FIR过滤器的实现,其中的卷积部分是这么写的:

val muls = Wire(Vec(length, UInt(8.W)))
for(i <- 0 until length) {
  if(i == 0) muls(i) := io.in * io.consts(i)
  else       muls(i) := regs(i - 1) * io.consts(i)
}

val scan = Wire(Vec(length, UInt(8.W)))
for(i <- 0 until length) {
  if(i == 0) scan(i) := muls(i)
  else scan(i) := muls(i) + scan(i - 1)
}

io.out := scan(length - 1)

简单说一下它的基本思想,首先把io.in的每个元素与对应的io.const相乘,然后存放到muls中,然后muls中的元素累加到scan,其中scan(0) = muls(0)scan(1) = scan(0) + muls(1) = muls(0) + muls(1),一般化就是scan(n) = scan(n-1) + muls(n) = muls(0) + ... + muls(n-1) + muls(n),最后scan的最后一个元素(等于muls所有元素的值)被赋值到io.out

但是啊,真的很繁琐,罗里吧嗦半天就为了实现一个简单的操作。实际上呢,上面的操作可以写成一行代码:

io.out := (taps zip io.consts).map { case (a, b) => a * b }.reduce(_ + _)

它是怎么办到的呢?下面一点点来分析:

  1. 假设taps是所有采样的列表,也就是taps(0) = io.intaps(1) = regs(0)等;
  2. (taps zip io.consts)接受两个列表,tapsio.const,然后组合他们为一个列表,其每个元素都是输入的相应位置的元素的元组,具体来说,它的值就像这样:[(taps(0), io.consts(0)), (taps(1), io.consts(1)), ..., (taps(n), io.consts(n))]。注意,因为.是可以省略的,所以这个其实等价于(taps.zip(io.const))
  3. .map {case (a, b) => a*b}应用了一个匿名函数,这个函数接受一个二元素的元组然后返回他们的乘积,这个匿名函数应用在了列表的元素上并返回结果也是列表,结果为[taps(0) * io.consts(0), taps(1) * io.consts(1), ..., taps(n) * io.consts(n)],和之前的muls是等价的。现在先简单了解写匿名函数的语法,后面还会详细介绍;
  4. 最后,.reduce(_ + _)在列表的元素上应用了函数(元素之间相加)。然而,这里同样接受两个参数,第一个是当前的累加结果,第二个是列表元素(第一次迭代中两个参数都是列表元素)。这是通过圆括号里面的两个下划线给定的。假设从左到右遍历,结果会是(((muls(0) + muls(1)) + muls(2)) + ...) + muls(n),最先计算的有更深的括号嵌套。这个结果就是卷积的输出。

作为参数的函数

正式地来讲,像mapreduce这样的函数就是高阶函数,因为它们是以函数为参数的函数。结果也证明,使用高阶函数能够省不少事,可以封装一个一般的计算模式,允许你在写代码是专注于应用整体的逻辑而不是控制流,能写出非常简洁的代码。

指定函数的不同方法

上面已经提到了两种了,这里总结一下:

  1. 对于每个元素只会引用一次的函数,可能可以使用下划线_来指代每个元素,在上面的例子中,reduce的参数函数接受两个元素就可以被表示为_ + _。虽然方便,但这受限于极少数满足一些复杂规则的情况,所以如果不行的话,可以试试下面的方法:
  2. 显式指定输入参数列表。上面的规约操作可以被显式写成(a, b) => a + b,用的是把参数列表放在括号里的一般形式,然后跟着=>符号,再跟着引用这些参数的函数体;
  3. 如果需要解包元组的时候,就用case语句,就像case (a, b) => a + b这里的用法一样。这种情况是接受了单个参数,即一个两元素的元组,然后将其解包为变量ab,然后用于后面的函数体。

Scala中的实践

上上一篇文章中,我们介绍了Scala集合类API中的主要类,比如List这种。这些高阶函数其实也是这些API的一部分,比如上面的mapreduce都是List上的API。这一小节我们通过一些例子来熟悉这些方法。在例子中,简洁起见我们都在Scala的数字(Int)上进行操作,但是因为Chisel的操作符也是类似的,所以这些概念可以泛化。

Map

List[A].map有类型签名map[B](f: (A) ⇒ B): List[B]。后面会有专门的一篇讲解关于类型的知识,现在就把这个类型AB看成是Int或者SInt,意味着它们可以是软件类型也可以是硬件类型。

说人话就是,它接受一个类型为(f: (A) ⇒ B)的参数,或者是一个接受两个参数的函数,第一个参数类型为A,即和输入列表的元素类型一致,第二个参数类型就随便了,什么类型都可以,然后map会返回一个类型B的列表,即参数函数的返回值类型。

因为我们已经解释过了FIR例子中列表的行为,现在就直接看例子吧:

println(List(1, 2, 3, 4).map(x => x + 1))  // 函数中的显式参数列表
println(List(1, 2, 3, 4).map(_ + 1))  // 和上面的等价,但是隐式的
println(List(1, 2, 3, 4).map(_.toString + "a"))  // 输出元素类型可和输入不同

println(List((1, 5), (2, 6), (3, 7), (4, 8)).map { case (x, y) => x*y })  // 用case解包元组,注意这里用的是大括号

// 提一嘴,Scala中有构造连续数字列表的语法
println(0 to 10)  // to是inclusive的, 这里的10是包括在内的
println(0 until 10)  // until是exclusive的,这里的10不包括在内

// 上面生成的和列表的行为基本一致,生成索引的时候很有用
val myList = List("a", "b", "c", "d")
println((0 until 4).map(myList(_)))

输出如下:

List(2, 3, 4, 5)
List(2, 3, 4, 5)
List(1a, 2a, 3a, 4a)
List(5, 12, 21, 32)
Range 0 to 10
Range 0 until 10
Vector(a, b, c, d)

再来个简单的练习,想要让列表中的每个元素的翻倍,???处填什么代码呢?

println(List(1, 2, 3, 4).map(???))

显然,填_ * 2就行了。

zipWithIndex

zipWithIndex的类型签名是zipWithIndex: List[(A, Int)]

也就是说不接受任何参数,但是返回一个列表,其每个元素都是源数据和它的索引(第一个元素索引为0)。所以说,List("a", "b", "c", "d").zipWithIndex会返回List(("a", 0), ("b", 1), ("c", 2), ("d", 3))

这在某些操作中,需要元素索引的场合特别有用。

这个也很简单,直接上例子:

println(List(1, 2, 3, 4).zipWithIndex)  // 注意索引从0开始
println(List("a", "b", "c", "d").zipWithIndex)
println(List(("a", "b"), ("c", "d"), ("e", "f"), ("g", "h")).zipWithIndex)  // 嵌套元组

输出如下:

List((1,0), (2,1), (3,2), (4,3))
List((a,0), (b,1), (c,2), (d,3))
List(((a,b),0), ((c,d),1), ((e,f),2), ((g,h),3))

Reduce

List[A].reduce的类型签名和List[A].map差不多,为reduce (op: (A, A) ⇒ A),这里就很宽松了,A只需要是List类型的超类就行了,但是这里不讨论这些语法。

直接上例子:

println(List(1, 2, 3, 4).reduce((a, b) => a + b))  // 返回所有元素的和
println(List(1, 2, 3, 4).reduce(_ * _))  // 返回所有元素的积
println(List(1, 2, 3, 4).map(_ + 1).reduce(_ + _))  // 可以把reduce放在map后

输出为:

10
24
14

需要注意的是,在空列表上使用reduce是不行的:

println(List[Int]().reduce(_ * _))

会报错:

java.lang.UnsupportedOperationException: empty.reduceLeft

现在稍微练习以下,在???处填入代码,使得列表的元素先翻倍再累乘:

println(List(1, 2, 3, 4).map(???).reduce(???))

很简单,这么写就行:

println(List(1, 2, 3, 4).map(_ * 2).reduce(_ * _))

Fold

List[A].foldreduce类型,除了以不能指定规约运算的初始值。类型签名和reduce是类似的:fold(z: A)(op: (A, A) ⇒ A): A

注意,它有两个参数,第一个参数z是初始值,第二个参数是规约的函数。和reduce不同,对于空列表它不会失效,而是会直接返回初始值。

例子来了:

println(List(1, 2, 3, 4).fold(0)(_ + _))  // 等价于用reduce的累加
println(List(1, 2, 3, 4).fold(1)(_ + _))  // 和上面的差不多,但是从1开始累加
println(List().fold(1)(_ + _))  // 和reduce不一样,fold可以在空列表上执行

输出为:

10
11
1

小小练习一下,现在要用fold返回一个列表的累乘值的两倍,???处怎么写:

println(List(1, 2, 3, 4).fold(???)(???))

这还用想?

println(List(1, 2, 3, 4).fold(2)(_ * _))

不过需要注意的是,除非需要容忍空列表,不然还是用reduce更好。

Chisel中的实践——Decoupled Arbiter

现在结合上面所学,实现一个Decoupled的仲裁器,要求有nDecoupled输入和一个Decoupled输出。仲裁器选择有效通道中索引最低的转发到输出。

几点提示:

  1. 如果有任何输入有效的话,io.out.valid就为真;
  2. 可以考虑在模块内部给被选择的通道整个Wire
  3. 如果输出ready为真的话,且某个通道被选择,则对应的输入的ready为真,(注意这里把readyvalid耦合到一起去了,但是这里先忽略);
  4. 可能会用到map,尤其是用来返回子元素的Vec时,比如io.in.map(_.valid)就会返回输入Bundle的有效信号的列表;
  5. 可能用到PriorityMux(List[Bool, Bits]),接受一个列表的有效信号和数据,返回第一个有效的元素;
  6. 可能用到Vec的动态索引,通过一个UInt数来索引,比如io.in(0.U)

一样的,在???处填上自己的代码:

import chisel3._
import chisel3.util._
import chisel3.tester._
import chisel3.tester.RawTester.test

object MyModule extends App {
  class MyRoutingArbiter(numChannels: Int) extends Module {
    val io = IO(new Bundle {
      val in = Vec(numChannels, Flipped(Decoupled(UInt(8.W))))
      val out = Decoupled(UInt(8.W))
    } )

    // 在这里填上自己的代码
    ???
  }

  test(new MyRoutingArbiter(4)) { c =>
    // 设置初始值
    for(i <- 0 until 4) {
        c.io.in(i).valid.poke(false.B)
        c.io.in(i).bits.poke(i.U)
        c.io.out.ready.poke(true.B)
    }

    c.io.out.valid.expect(false.B)

    // 测试有背压的单输入有效的行为
    for (i <- 0 until 4) {
        c.io.in(i).valid.poke(true.B)
        c.io.out.valid.expect(true.B)
        c.io.out.bits.expect(i.U)

        c.io.out.ready.poke(false.B)
        c.io.in(i).ready.expect(false.B)

        c.io.out.ready.poke(true.B)
        c.io.in(i).valid.poke(false.B)
    }

    // 测试有背压的多输入有效的行为
    c.io.in(1).valid.poke(true.B)
    c.io.in(2).valid.poke(true.B)
    c.io.out.bits.expect(1.U)
    c.io.in(1).ready.expect(true.B)
    c.io.in(0).ready.expect(false.B)

    c.io.out.ready.poke(false.B)
    c.io.in(1).ready.expect(false.B)
  }

  println("SUCCESS!!") // Scala Code: if we get here, our tests passed!
}

问号处代码应为:

io.out.valid := io.in.map(_.valid).reduce(_ || _)
val channel = PriorityMux(
  io.in.map(_.valid).zipWithIndex.map { case (valid, index) => (valid, index.U) }
)
io.out.bits := io.in(channel).bits
io.in.map(_.ready).zipWithIndex.foreach { case (ready, index) =>
  ready := io.out.ready && channel === index.U
}

测试通过,下面解释一下:

  1. map取出io.invalid信号的列表,再规约进行或运算,就知道其中是否至少有一个有效;
  2. 构造一个优先级多路选择器,输入是valid信号和索引才行,所以对valid列表进行了一个zipWithIndex,然后继续应用map,将索引转为硬件格式,多路选择器的输出就是有效且最低索引的索引;
  3. 输出的bits就是channel索引对应的输入bits
  4. 要对输入的每个ready信号进行操作,需要用到一个foreach,它也是List上的函数,虽然前面没提到但是很好懂,就是对于每个元素进行操作,无需返回值,这里用同样的方法提取除了ready和索引,然后根据输出的ready信号和被通道的索引来判断是否应该将这个输入的ready信号置为有效。

有一点需要提一下,为什么不直接用PriorityMux输出io.in.bits呢?因为需要设置io.in.ready位,所以必须要知道被选择的输入的索引,所以PriorityMux用来找索引了。

  • 2
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值