由之前的GCN网络的介绍可以得知,我们需要输入两个乘数(两个节点的节点度),并输出他们乘积的-1/2次方,此处由于当时设计的booth编码的乘法器为有符号数,而此处是无符号数,实在懒得再写一份了,这里直接写个乘号,留给chisel自己去优化吧
此处设计输入的节点度位宽为8位,即支持节点最高与255个节点相连。考虑到-1/2次方对于硬件设计中的艰难性——无论是对时间和资源的消耗还是对脑力的消耗,这里对其中的主要使用查找表的方式获取对应的结果。同时由于硬件设计中寄存器位宽的确定性。而乘数的-1/2次方一定是一个小于1的数,无法直接进行表示。这里将乘数的-1/2次方所得到的结果乘以256再取整,算是对这个网络进行了一个量化操作吧。
先使用Python生成一张乘数结果和他的-1/2次方乘以256的表,Python代码如下:
import math
# 定义查找表的大小
lut_size = 300 # 2^16
# 计算查找表的值
lut_values = {}
for i in range(1, lut_size + 1):
# 计算根号倒数乘以256
value = math.sqrt(1 / i) * 256
# 四舍五入到最接近的整数
rounded_value = round(value)
# 将值转换为32位无符号整数
lut_values[i] = rounded_value & 0xFFFFFFFF
# 将查找表保存到文件
with open('sqrt_inv_lut.txt', 'w') as file:
for key, value in lut_values.items():
# 格式化输出,适合Chisel的查找表初始化
# file.write(f'({key}.U(16.W), {value}.U(8.W)),\n')
file.write(f'{key}.U -> {value}.U(8.W),\n')
print(f'Lookup table with {lut_size} entries has been saved to sqrt_inv_lut.txt')
生成的数据格式如下:
1.U -> 256.U(8.W),
2.U -> 181.U(8.W),
3.U -> 148.U(8.W),
4.U -> 128.U(8.W),
5.U -> 114.U(8.W),
6.U -> 105.U(8.W),
7.U -> 97.U(8.W),
8.U -> 91.U(8.W),
9.U -> 85.U(8.W),
10.U -> 81.U(8.W),
此处按理说可以直接完成查找表的编写了,但同时要考虑这样一件事。查找表规模太大了,对软硬件资源的消耗巨大,继续观察生成出来的数据特征。对于逐渐增大的输入,输出在前期变化大,而后期变化小。甚至最后上万个不同的输入对应一个相同的输出。所以,我们可以仅对前面变化较快的地方建立相应的查找表,而当输入变大后,我们可以直接采用条件判断的方式进行输出的划分。此处我选取输出大于15的部分建立查找表,而对于输出小于15的部分直接采用条件判断语句进行对输出值的选择。有一点值得注意的是,在输入最高为255的时候,输出四舍五入最小值为1,所以代码中的划分是精确的。
chisel代码
package FAM
import chisel3._
import chisel3.util._
import os.read
import chisel3._
/* 输入起始节点和目的节点的度,计算得到他们乘积的-1/2次方
// 假设节点度最高256位,8bit
// 这里为实现整数的运算,将起始节点度与目的节点度的-1/2次方的乘积 * 256再取整
// 使用查找表完成乘法后的计算
*/
class SqrtInv(val Degree_Data_width: Int = 8) extends Module {
val io = IO(new Bundle {
val start_point_degree = Input(UInt(Degree_Data_width.W))
val end_point_degree = Input(UInt(Degree_Data_width.W))
val start = Input(Bool())
val done = Output(Bool())
val out = Output(UInt((Degree_Data_width).W))
})
// 初始化状态机。共四个状态:空闲状态、输入节点乘法计算,查找表,输出
val state = RegInit(0.U(2.W))
val idle :: multiply :: lut :: done :: Nil = Enum(4)
// 计算乘积
val mulResult = RegInit(0.U((2*Degree_Data_width).W))
val result = RegInit(0.U(8.W))
// 初始化LUT
val LUT_array = Array(
1.U -> 255.U(8.W),
2.U -> 181.U(8.W),
3.U -> 148.U(8.W),
4.U -> 128.U(8.W),
5.U -> 114.U(8.W),
6.U -> 105.U(8.W),
7.U -> 97.U(8.W),
8.U -> 91.U(8.W),
9.U -> 85.U(8.W),
10.U -> 81.U(8.W),
11.U -> 77.U(8.W),
12.U -> 74.U(8.W),
13.U -> 71.U(8.W),
14.U -> 68.U(8.W),
15.U -> 66.U(8.W),
16.U -> 64.U(8.W),
17.U -> 62.U(8.W),
18.U -> 60.U(8.W),
19.U -> 59.U(8.W),
20.U -> 57.U(8.W),
21.U -> 56.U(8.W),
22.U -> 55.U(8.W),
23.U -> 53.U(8.W),
24.U -> 52.U(8.W),
25.U -> 51.U(8.W),
26.U -> 50.U(8.W),
27.U -> 49.U(8.W),
28.U -> 48.U(8.W),
29.U -> 48.U(8.W),
30.U -> 47.U(8.W),
31.U -> 46.U(8.W),
32.U -> 45.U(8.W),
33.U -> 45.U(8.W),
34.U -> 44.U(8.W),
35.U -> 43.U(8.W),
36.U -> 43.U(8.W),
37.U -> 42.U(8.W),
38.U -> 42.U(8.W),
39.U -> 41.U(8.W),
40.U -> 40.U(8.W),
41.U -> 40.U(8.W),
42.U -> 40.U(8.W),
43.U -> 39.U(8.W),
44.U -> 39.U(8.W),
45.U -> 38.U(8.W),
46.U -> 38.U(8.W),
47.U -> 37.U(8.W),
48.U -> 37.U(8.W),
49.U -> 37.U(8.W),
50.U -> 36.U(8.W),
51.U -> 36.U(8.W),
52.U -> 36.U(8.W),
53.U -> 35.U(8.W),
54.U -> 35.U(8.W),
55.U -> 35.U(8.W),
56.U -> 34.U(8.W),
57.U -> 34.U(8.W),
58.U -> 34.U(8.W),
59.U -> 33.U(8.W),
60.U -> 33.U(8.W),
61.U -> 33.U(8.W),
62.U -> 33.U(8.W),
63.U -> 32.U(8.W),
64.U -> 32.U(8.W),
65.U -> 32.U(8.W),
66.U -> 32.U(8.W),
67.U -> 31.U(8.W),
68.U -> 31.U(8.W),
69.U -> 31.U(8.W),
70.U -> 31.U(8.W),
71.U -> 30.U(8.W),
72.U -> 30.U(8.W),
73.U -> 30.U(8.W),
74.U -> 30.U(8.W),
75.U -> 30.U(8.W),
76.U -> 29.U(8.W),
77.U -> 29.U(8.W),
78.U -> 29.U(8.W),
79.U -> 29.U(8.W),
80.U -> 29.U(8.W),
81.U -> 28.U(8.W),
82.U -> 28.U(8.W),
83.U -> 28.U(8.W),
84.U -> 28.U(8.W),
85.U -> 28.U(8.W),
86.U -> 28.U(8.W),
87.U -> 27.U(8.W),
88.U -> 27.U(8.W),
89.U -> 27.U(8.W),
90.U -> 27.U(8.W),
91.U -> 27.U(8.W),
92.U -> 27.U(8.W),
93.U -> 27.U(8.W),
94.U -> 26.U(8.W),
95.U -> 26.U(8.W),
96.U -> 26.U(8.W),
97.U -> 26.U(8.W),
98.U -> 26.U(8.W),
99.U -> 26.U(8.W),
100.U -> 26.U(8.W),
101.U -> 25.U(8.W),
102.U -> 25.U(8.W),
103.U -> 25.U(8.W),
104.U -> 25.U(8.W),
105.U -> 25.U(8.W),
106.U -> 25.U(8.W),
107.U -> 25.U(8.W),
108.U -> 25.U(8.W),
109.U -> 25.U(8.W),
110.U -> 24.U(8.W),
111.U -> 24.U(8.W),
112.U -> 24.U(8.W),
113.U -> 24.U(8.W),
114.U -> 24.U(8.W),
115.U -> 24.U(8.W),
116.U -> 24.U(8.W),
117.U -> 24.U(8.W),
118.U -> 24.U(8.W),
119.U -> 23.U(8.W),
120.U -> 23.U(8.W),
121.U -> 23.U(8.W),
122.U -> 23.U(8.W),
123.U -> 23.U(8.W),
124.U -> 23.U(8.W),
125.U -> 23.U(8.W),
126.U -> 23.U(8.W),
127.U -> 23.U(8.W),
128.U -> 23.U(8.W),
129.U -> 23.U(8.W),
130.U -> 22.U(8.W),
131.U -> 22.U(8.W),
132.U -> 22.U(8.W),
133.U -> 22.U(8.W),
134.U -> 22.U(8.W),
135.U -> 22.U(8.W),
136.U -> 22.U(8.W),
137.U -> 22.U(8.W),
138.U -> 22.U(8.W),
139.U -> 22.U(8.W),
140.U -> 22.U(8.W),
141.U -> 22.U(8.W),
142.U -> 21.U(8.W),
143.U -> 21.U(8.W),
144.U -> 21.U(8.W),
145.U -> 21.U(8.W),
146.U -> 21.U(8.W),
147.U -> 21.U(8.W),
148.U -> 21.U(8.W),
149.U -> 21.U(8.W),
150.U -> 21.U(8.W),
151.U -> 21.U(8.W),
152.U -> 21.U(8.W),
153.U -> 21.U(8.W),
154.U -> 21.U(8.W),
155.U -> 21.U(8.W),
156.U -> 20.U(8.W),
157.U -> 20.U(8.W),
158.U -> 20.U(8.W),
159.U -> 20.U(8.W),
160.U -> 20.U(8.W),
161.U -> 20.U(8.W),
162.U -> 20.U(8.W),
163.U -> 20.U(8.W),
164.U -> 20.U(8.W),
165.U -> 20.U(8.W),
166.U -> 20.U(8.W),
167.U -> 20.U(8.W),
168.U -> 20.U(8.W),
169.U -> 20.U(8.W),
170.U -> 20.U(8.W),
171.U -> 20.U(8.W),
172.U -> 20.U(8.W),
173.U -> 19.U(8.W),
174.U -> 19.U(8.W),
175.U -> 19.U(8.W),
176.U -> 19.U(8.W),
177.U -> 19.U(8.W),
178.U -> 19.U(8.W),
179.U -> 19.U(8.W),
180.U -> 19.U(8.W),
181.U -> 19.U(8.W),
182.U -> 19.U(8.W),
183.U -> 19.U(8.W),
184.U -> 19.U(8.W),
185.U -> 19.U(8.W),
186.U -> 19.U(8.W),
187.U -> 19.U(8.W),
188.U -> 19.U(8.W),
189.U -> 19.U(8.W),
190.U -> 19.U(8.W),
191.U -> 19.U(8.W),
192.U -> 18.U(8.W),
193.U -> 18.U(8.W),
194.U -> 18.U(8.W),
195.U -> 18.U(8.W),
196.U -> 18.U(8.W),
197.U -> 18.U(8.W),
198.U -> 18.U(8.W),
199.U -> 18.U(8.W),
200.U -> 18.U(8.W),
201.U -> 18.U(8.W),
202.U -> 18.U(8.W),
203.U -> 18.U(8.W),
204.U -> 18.U(8.W),
205.U -> 18.U(8.W),
206.U -> 18.U(8.W),
207.U -> 18.U(8.W),
208.U -> 18.U(8.W),
209.U -> 18.U(8.W),
210.U -> 18.U(8.W),
211.U -> 18.U(8.W),
212.U -> 18.U(8.W),
213.U -> 18.U(8.W),
214.U -> 17.U(8.W),
215.U -> 17.U(8.W),
216.U -> 17.U(8.W),
217.U -> 17.U(8.W),
218.U -> 17.U(8.W),
219.U -> 17.U(8.W),
220.U -> 17.U(8.W),
221.U -> 17.U(8.W),
222.U -> 17.U(8.W),
223.U -> 17.U(8.W),
224.U -> 17.U(8.W),
225.U -> 17.U(8.W),
226.U -> 17.U(8.W),
227.U -> 17.U(8.W),
228.U -> 17.U(8.W),
229.U -> 17.U(8.W),
230.U -> 17.U(8.W),
231.U -> 17.U(8.W),
232.U -> 17.U(8.W),
233.U -> 17.U(8.W),
234.U -> 17.U(8.W),
235.U -> 17.U(8.W),
236.U -> 17.U(8.W),
237.U -> 17.U(8.W),
238.U -> 17.U(8.W),
239.U -> 17.U(8.W),
240.U -> 17.U(8.W),
241.U -> 16.U(8.W),
242.U -> 16.U(8.W),
243.U -> 16.U(8.W),
244.U -> 16.U(8.W),
245.U -> 16.U(8.W),
246.U -> 16.U(8.W),
247.U -> 16.U(8.W),
248.U -> 16.U(8.W),
249.U -> 16.U(8.W),
250.U -> 16.U(8.W),
251.U -> 16.U(8.W),
252.U -> 16.U(8.W),
253.U -> 16.U(8.W),
254.U -> 16.U(8.W),
255.U -> 16.U(8.W),
256.U -> 16.U(8.W),
257.U -> 16.U(8.W),
258.U -> 16.U(8.W),
259.U -> 16.U(8.W),
260.U -> 16.U(8.W),
261.U -> 16.U(8.W),
262.U -> 16.U(8.W),
263.U -> 16.U(8.W),
264.U -> 16.U(8.W),
265.U -> 16.U(8.W),
266.U -> 16.U(8.W),
267.U -> 16.U(8.W),
268.U -> 16.U(8.W),
269.U -> 16.U(8.W),
270.U -> 16.U(8.W),
271.U -> 16.U(8.W),
272.U -> 16.U(8.W)
)
// 状态机执行
switch(state) {
is(idle) {
when(io.start) {
state := multiply
}
}
is(multiply) {
mulResult := io.start_point_degree * io.end_point_degree
state := lut
}
is(lut) {
// 使用输入值作为索引来查找结果
// 接受选择信号、一个默认值,一个选择表。如果匹配成功,则按照匹配值输出,否则按默认值输出
// 观察计算结果,前面数值变化相对频繁,而后面数值变化不大,这里将整个过程进行分类。变化相对频繁的地方使用查找表进行查找,
// 而没那么频繁的使用条件判断所属区间进行处理
when(mulResult < 273.U) { // 查找表所属区间,即LUT_depth
result := MuxLookup(mulResult, 0.U, LUT_array)
}
.elsewhen(mulResult < 312.U) { // 15
result := 15.U
}
.elsewhen(mulResult < 360.U) { // 14
result := 14.U
}
.elsewhen(mulResult < 420.U) { // 13
result := 13.U
}
.elsewhen(mulResult < 496.U) { // 12
result := 12.U
}
.elsewhen(mulResult < 595.U) { // 11
result := 11.U
}
.elsewhen(mulResult < 727.U) { // 10
result := 10.U
}
.elsewhen(mulResult < 908.U) { // 9
result := 9.U
}
.elsewhen(mulResult < 1166.U) { // 8
result := 8.U
}
.elsewhen(mulResult < 1552.U) { // 7
result := 7.U
}
.elsewhen(mulResult < 2167.U) { // 6
result := 6.U
}
.elsewhen(mulResult < 3237.U) { // 5
result := 5.U
}
.elsewhen(mulResult < 5350.U) { // 4
result := 4.U
}
.elsewhen(mulResult < 10487.U) { // 3
result := 3.U
}
.elsewhen(mulResult < 29128.U) { // 2
result := 2.U
}
.otherwise { // 1
result := 1.U
}
state := done
}
is(done) {
state := idle
}
}
io.done := (state === done)
io.out := result
}
// 实例化模块并运行测试,同时生成Verilog代码
object SqrtInv extends App {
(new chisel3.stage.ChiselStage).emitVerilog(new SqrtInv(), Array("--target-dir", "./verilog/FAM"))
}
测试代码
import scala.util.Random
import org.scalatest._
import chiseltest._
import chisel3._
import FAM.SqrtInv
// 乘累加器的测试类
class Power_1_2Test extends FreeSpec with ChiselScalatestTester {
"Power -1/2 should pass" in {
test(new SqrtInv)
.withAnnotations(Seq(WriteVcdAnnotation)) // generate the .vcd waveform file as output
{ c =>
println("Start Testing")
for (i <- 0 until 10) {
val a = Random.nextInt(256) // 生成0到255之间的随机数
val b = Random.nextInt(256)
c.io.start_point_degree.poke(a.U) // 将随机数a作为无符号数输入
c.io.end_point_degree.poke(b.U) // 将随机数b作为无符号数输入
c.io.start.poke(true.B)
c.clock.step(2)
while (c.io.done.peekBoolean() === false) {
c.clock.step(1)
}
val expectedResult = math.round(256/math.sqrt(a * b)) // 计算预期乘积
val actualResult = c.io.out.peek().litValue.toLong // 获取实际乘积
/*
c: 这是测试环境中MAC模块的实例。
c.io.result: 这是指向模块输出端口result的引用。
peek(): 这是一个Chisel测试方法,用于在不推进时钟的情况下读取端口的当前值。
litValue: 这是一个方法,用于从Chisel的Data类型中提取实际的Scala值(在这个例子中是BigInt)
*/
println(s"Iteration: $i, A: $a, B: $b, Expected Result: $expectedResult, Actual Result: $actualResult")
assert(actualResult === expectedResult, s"Product is incorrect at iteration $i!\n Start_point_degree is $a, end point degree is $b.\n Expected: $expectedResult, Actual: $actualResult")
}
}
}
}
测试结果(使用vscode的metals插件完成测试,也可直接sbt test对所有文件进行测试)
Power_1_2Test
Start Testing
Iteration: 0, A: 167, B: 77, Expected Result: 2, Actual Result: 2
Iteration: 1, A: 180, B: 14, Expected Result: 5, Actual Result: 5
Iteration: 2, A: 212, B: 114, Expected Result: 2, Actual Result: 2
Iteration: 3, A: 171, B: 195, Expected Result: 1, Actual Result: 1
Iteration: 4, A: 196, B: 219, Expected Result: 1, Actual Result: 1
Iteration: 5, A: 101, B: 209, Expected Result: 2, Actual Result: 2
Iteration: 6, A: 138, B: 111, Expected Result: 2, Actual Result: 2
Iteration: 7, A: 113, B: 69, Expected Result: 3, Actual Result: 3
Iteration: 8, A: 115, B: 245, Expected Result: 2, Actual Result: 2
Iteration: 9, A: 219, B: 153, Expected Result: 1, Actual Result: 1
- Power -1/2 should pass
Execution took 2.75s
1 tests, 1 passed
All tests in Power_1_2Test passed