1. FFT背景知识
FFT背景知识可参看博客《十分简明易懂的FFT(快速傅里叶变换)》。
2. Halo中的FFT代码实现
在4核8G ubuntu16.04服务器上运行:
cargo test test_fft -- --nocapture
`test_fft`函数中实现的是对两个999阶(1000个系数)多项式的乘法运算。在该函数内,分别进行了直接乘法运算`naive_product`和通过FFT实现的乘法运算`multiply_polynomials`。
2.1 将`a`和`b`系数列表均扩展为$2^{exp}$项
`multiply_polynomials`函数会首先计算两个多项式乘积的系数个数(本例为1999),并将其向上扩展为不小于它的2的幂$2^{exp}$(本例为$2^{11}=2048$),然后将`a`和`b`系数列表补零扩展为$2^{exp}$项。取2的幂,是因为radix-2 FFT要求点数为$2^{exp}$。要求不小于乘积的系数个数,是为了让点值相乘所对应的循环卷积不发生回绕,从而恰好等于多项式乘积:
// Degree of the product polynomial: deg(a) + deg(b) (1998 in the 1000-coefficient example).
// NOTE(review): `a.len() - 1` / `b.len() - 1` underflow (usize) if either list is empty.
let degree_of_result = (a.len() - 1) + (b.len() - 1);
// Number of coefficients of the product (1999 in the example).
let coeffs_of_result = degree_of_result + 1;
// Compute the size of our evaluation domain: the smallest power of two
// m = 2^exp with m >= coeffs_of_result (m = 2048, exp = 11 in the example).
let mut m = 1;
let mut exp = 0;
while m < coeffs_of_result {
    m *= 2;
    exp += 1;
    // The pairing-friendly curve may not be able to support
    // large enough (radix2) evaluation domains.
    // NOTE(review): this also rejects exp == F::S; if F::ALPHA is a primitive
    // 2^S-th root of unity a 2^S domain would still work — confirm the strict bound is intended.
    if exp >= F::S {
        panic!("polynomial too large");
    }
}
// Extend the vectors with zeroes up to the domain size m = 2^exp,
// so both FFT inputs have exactly m entries.
a.resize(m, F::zero());
b.resize(m, F::zero());
2.2 获取$2^{exp}$-th primitive root of unity
`F::ALPHA`为$2^{32}$-th primitive root of unity。本原$2^{k}$次单位根的平方是本原$2^{k-1}$次单位根,因此将`F::ALPHA`连续平方$S-exp$次(此处$S=32$),即可得到相应的$2^{exp}$-th primitive root of unity:
// Compute alpha, the 2^exp primitive root of unity.
// Assuming F::ALPHA is a primitive 2^S-th root of unity: each squaring halves the
// order of the root, so squaring it (S - exp) times leaves a primitive 2^exp-th root.
let mut alpha = F::ALPHA;
for _ in exp..F::S {
    alpha = alpha.square();
}
// alpha is now a 2^exp-th primitive root of unity (2^11 = 2048 in the example).
2.3 分别对`a`和`b`系数列表做FFT
`alpha`为$2^{exp}$-th primitive root of unity,$exp=11$:
// Evaluate both zero-padded polynomials, in place, on the 2^exp points generated by alpha.
best_fft(&mut a, alpha, exp);
best_fft(&mut b, alpha, exp);
注意`best_fft(&mut a, alpha, exp);`是原地(in-place)变换。调用结束后,数组$a$中依次存放的是多项式$A(x)=a_0+a_1x+a_2x^2+\dots+a_{n-1}x^{n-1}$(其中$n=2^{exp}$)在$x$依次取$w_n^0,w_n^1,w_n^2,\dots,w_n^{n-1}$时的值,其中$w_n^1$即`alpha`:
$$a=[A(w_n^0),A(w_n^1),A(w_n^2),\dots,A(w_n^{n-1})]$$
也就是说,通过`best_fft`函数,可将系数表示的多项式转换为点值表示:$(w_n^0,A(w_n^0)),\dots,(w_n^{n-1},A(w_n^{n-1}))$。
`best_fft`会比较`exp`(即`log_n`)与CPU核数以2为底的对数(`log_cpus`)的大小,决定调用串行方式`serial_fft`还是并行方式`parallel_fft`:当`log_n <= log_cpus`时串行执行,否则并行执行。
/// Runs an in-place radix-2 FFT of size `2^log_n` over `a` with root `omega`,
/// picking the serial or the multi-threaded implementation.
///
/// With `log_cpus = log2_floor(num_cpus::get())`, small transforms
/// (`log_n <= log_cpus`) run on the current thread via `serial_fft`; larger
/// ones are split across `2^log_cpus` threads via `parallel_fft`.
fn best_fft<F: Field>(a: &mut [F], omega: F, log_n: u32) {
    let log_cpus = log2_floor(num_cpus::get());

    if log_n > log_cpus {
        parallel_fft(a, omega, log_n, log_cpus);
    } else {
        serial_fft(a, omega, log_n);
    }
}
2.3.1 并行FFT算法parallel_fft
/// Multi-threaded in-place FFT of size `n = 2^log_n` over `a`, using `2^log_cpus` threads.
///
/// `omega` is expected to be a primitive `2^log_n`-th root of unity (in the 2048-point
/// example: `exp = 11`, `log_n = 11`, `log_cpus = 2`). Thread `j` builds row `j` of a
/// `2^log_cpus x 2^(log_n - log_cpus)` matrix with
/// `row[i] = omega^(j*i) * sum_s a[i + s*2^(log_n-log_cpus)] * omega^(j*s*2^(log_n-log_cpus))`,
/// then runs a serial FFT of that row with root `omega^(2^log_cpus)`, a primitive
/// `2^(log_n - log_cpus)`-th root of unity. Afterwards `rows[j][q] = A(omega^(q*2^log_cpus + j))`,
/// which the final unshuffle copies back into `a`.
///
/// # Panics
/// Panics if `log_n < log_cpus`, or if the thread scope returns an error
/// (with crossbeam's `thread::scope`, that happens when a worker thread panicked).
fn parallel_fft<F: Field>(a: &mut [F], omega: F, log_n: u32, log_cpus: u32) {
    assert!(log_n >= log_cpus);

    let num_cpus = 1 << log_cpus;
    let log_new_n = log_n - log_cpus;
    // One row per thread, each of length 2^log_new_n.
    let mut rows = vec![vec![F::zero(); 1 << log_new_n]; num_cpus];
    // omega^(2^log_cpus): primitive 2^log_new_n-th root of unity shared by every sub-FFT.
    let new_omega = omega.pow(&[num_cpus as u64, 0, 0, 0]);

    thread::scope(|scope| {
        let coeffs = &*a;

        for (j, row) in rows.iter_mut().enumerate() {
            scope.spawn(move |_| {
                // omega^j: advances the running twiddle once per column.
                let omega_j = omega.pow(&[j as u64, 0, 0, 0]);
                // omega^(j * 2^log_new_n): advances once per block `s`. Since its
                // (2^log_cpus)-th power is omega^(j * 2^log_n) = 1, `elt` is back at its
                // column-start value after the inner loop.
                let omega_step = omega.pow(&[(j as u64) << log_new_n, 0, 0, 0]);

                // Shuffle into a sub-FFT input.
                let mut elt = F::one();
                for (i, cell) in row.iter_mut().enumerate() {
                    for s in 0..num_cpus {
                        // Coefficient i of block s (block s covers a[s*2^log_new_n ..]).
                        let idx = ((s << log_new_n) + i) % (1 << log_n);
                        let mut term = coeffs[idx];
                        term *= elt;
                        *cell += term;
                        elt *= omega_step;
                    }
                    elt *= omega_j;
                }

                // Perform the length-2^log_new_n sub-FFT on this row.
                serial_fft(row, new_omega, log_new_n);
            });
        }
    })
    .unwrap();

    // Unshuffle: a[idx] = rows[idx mod 2^log_cpus][idx / 2^log_cpus].
    let mask = (1 << log_cpus) - 1;
    for (idx, slot) in a.iter_mut().enumerate() {
        *slot = rows[idx & mask][idx >> log_cpus];
    }
}
FFT算法的本质是将以系数表示的多项式转化为以点值表示。对于方程式:
$$y=a_0+a_1x+a_2x^2+\dots+a_{n-1}x^{n-1}$$
其中$n=2^{exp}$,上例中$exp=11$,$n=2048$。`omega`(记为$w_n^1$)为primitive $n$-th root of unity,即满足$(w_n^1)^n=1$,且对任意$0<k<n$有$(w_n^1)^k\neq 1$,这保证了下面的$n$个点互不相同。
FFT将该多项式转为$n$个互不相同的点值序列$(x_0,y_0),(x_1,y_1),(x_2,y_2),\dots,(x_{n-1},y_{n-1})$,其中$x_k=(w_n^1)^k=w_n^k$。上例中`parallel_fft`函数执行完毕后,数组`a`中存储的即为所有的$y$值,即$a[k]=y_k$。这些值先在二维数组`tmp`中计算,再由Unshuffle步骤写回:$a[k]=tmp[k \bmod 4][\lfloor k/4 \rfloor]$,对应代码`*a = tmp[idx & mask][idx >> log_cpus]`。
以4核($p=4$)CPU为例,可将2047阶($n=2048$,$n/p=512$)多项式拆分为四段,由四个线程分别执行:
$$
\begin{aligned}
A(x)={}&a_0+a_1x+a_2x^2+\dots+a_{511}x^{511}\\
&+x^{512}(a_{512}+a_{513}x+a_{514}x^2+\dots+a_{1023}x^{511})\\
&+x^{1024}(a_{1024}+a_{1025}x+a_{1026}x^2+\dots+a_{1535}x^{511})\\
&+x^{1536}(a_{1536}+a_{1537}x+a_{1538}x^2+\dots+a_{2047}x^{511})
\end{aligned}
$$
逐列展开,即把各段中$x$次数相同的项合并:
$$A(x)=C_0(x^{512})+xC_1(x^{512})+x^2C_2(x^{512})+\dots+x^{511}C_{511}(x^{512})$$
其中:
$$
\begin{aligned}
C_0(x)&=a_0+a_{512}x+a_{1024}x^2+a_{1536}x^3\\
C_1(x)&=a_1+a_{513}x+a_{1025}x^2+a_{1537}x^3\\
&\;\;\vdots\\
C_{511}(x)&=a_{511}+a_{1023}x+a_{1535}x^2+a_{2047}x^3
\end{aligned}
$$
一般地,$C_i(x)=a_i+a_{512+i}x+a_{1024+i}x^2+a_{1536+i}x^3$,$0\le i<512$。
再设$0\le k<n=2048$,把$w_n^k=w_{2048}^k$作为$x$值代入$A(x)$多项式,并利用$w_{2048}^{512}=w_4$,有:
$$
\begin{aligned}
A(w_{2048}^k)&=C_0(w_{2048}^{512k})+w_{2048}^kC_1(w_{2048}^{512k})+\dots+w_{2048}^{511k}C_{511}(w_{2048}^{512k})\\
&=C_0(w_4^k)+w_{2048}^kC_1(w_4^k)+\dots+w_{2048}^{511k}C_{511}(w_4^k)
\end{aligned}
$$
由于$w_4^4=1$,$w_4^k$只取决于$k \bmod 4$。因此,根据单位根的性质,只需分别计算$k=0,1,2,3$时相应的$C_0(w_4^k),C_1(w_4^k),\dots,C_{511}(w_4^k)$值,即可很方便地计算任意$k<2048$时$A(w_n^k)$的值。
在Halo代码中,$w_n^1$对应为`omega`。对第$j$个线程(即上式中$k \bmod 4=j$):$w_4^j$对应为`omega_step`(`let omega_step = omega.pow(&[(j as u64) << log_new_n, 0, 0, 0]);`,即$\omega^{j\cdot 2^9}$),$w_n^j$对应为`omega_j`(即$\omega^j$)。
Halo代码中,`tmp`为$2^2\times 2^9$矩阵。在执行子FFT之前,其第$j$行第$i$列为$tmp[j][i]=w_{2048}^{ji}\,C_i(w_4^j)$,即:
$$
\begin{aligned}
tmp[0]&=[C_0(w_4^0),\ w_{2048}^{(0\cdot 1)}C_1(w_4^0),\ \dots,\ w_{2048}^{(0\cdot 511)}C_{511}(w_4^0)]\\
tmp[1]&=[C_0(w_4^1),\ w_{2048}^{(1\cdot 1)}C_1(w_4^1),\ \dots,\ w_{2048}^{(1\cdot 511)}C_{511}(w_4^1)]\\
tmp[2]&=[C_0(w_4^2),\ w_{2048}^{(2\cdot 1)}C_1(w_4^2),\ \dots,\ w_{2048}^{(2\cdot 511)}C_{511}(w_4^2)]\\
tmp[3]&=[C_0(w_4^3),\ w_{2048}^{(3\cdot 1)}C_1(w_4^3),\ \dots,\ w_{2048}^{(3\cdot 511)}C_{511}(w_4^3)]
\end{aligned}
$$
令$k=4q+j$($0\le j<4$,$0\le q<512$),则$w_{2048}^{ki}=w_{512}^{qi}\,w_{2048}^{ji}$,且$C_i(w_4^k)=C_i(w_4^j)$,于是
$$A(w_{2048}^{4q+j})=\sum_{i=0}^{511}w_{512}^{qi}\,tmp[j][i]$$
这正是对第$j$行做一次长度为512的FFT,即代码中的`serial_fft(tmp, new_omega, log_new_n)`,其中`new_omega`$=\omega^4=w_{512}$。子FFT完成后$tmp[j][q]=A(w_{2048}^{4q+j})$,最后由Unshuffle步骤写回$a[4q+j]$。
2.3.2 串行FFT算法serial_fft
// Perform sub-FFT
serial_fft(tmp, new_omega, log_new_n);
// Here `tmp` is row j of the matrix (tmp[j]), new_omega is a 2^9-th primitive root of unity, and log_new_n is 9.
以下代码段对多项式系数数组$[a_0,a_1,a_2,\dots,a_{n-1}]$做位逆序(bit-reversal)置换:把下标$k$的元素与下标$rev(k)$的元素交换,其中$rev(k)$是$k$的$\log_2 n$位二进制表示逆序后的值(`k < rk`的判断保证每对只交换一次)。其效果等价于递归地按奇偶重新排列。第一层把偶数下标系数放在前半段、奇数下标系数放在后半段,即
$$[a_0,a_2,a_4,\dots,a_{n-2},a_1,a_3,a_5,\dots,a_{n-1}]$$
然后对前后两半分别按同样规则继续排列,直到每段只剩一个元素。这样后续的蝶形运算即可原地、自底向上地完成。
// Bit-reversal permutation: swap each index k with its log_n-bit reverse rk.
// The `k < rk` check makes every pair swap exactly once (palindromic indices stay put).
for k in 0..n {
    let rk = bitreverse(k, log_n);
    if k < rk {
        a.swap(rk as usize, k as usize);
    }
}
以$A(x)=a_0+a_1x+a_2x^2+a_3x^3$为例,位逆序置换后数组变为$[a_0,a_2,a_1,a_3]$,下面程序的演示效果如上图所示。
/// In-place iterative radix-2 (Cooley–Tukey) FFT.
///
/// Treats `a` as the coefficients of `A(x) = a[0] + a[1]x + ... + a[n-1]x^(n-1)` with
/// `n = 2^log_n` and overwrites it with `[A(omega^0), A(omega^1), ..., A(omega^(n-1))]`.
/// `omega` is expected to be a primitive `n`-th root of unity.
///
/// # Panics
/// Panics if `a.len()` (as `u32`) is not `2^log_n`.
fn serial_fft<F: Field>(a: &mut [F], omega: F, log_n: u32) {
    /// Returns `value` with its lowest `bits` bits written in reverse order.
    fn bitreverse(value: u32, bits: u32) -> u32 {
        (0..bits).fold(0, |acc, bit| (acc << 1) | ((value >> bit) & 1))
    }

    let len = a.len() as u32;
    assert_eq!(len, 1 << log_n);

    // Bit-reversal permutation; `idx < rev` makes every pair swap exactly once.
    for idx in 0..len {
        let rev = bitreverse(idx, log_n);
        if idx < rev {
            a.swap(idx as usize, rev as usize);
        }
    }

    // Bottom-up butterflies: each stage merges adjacent size-`half` DFTs
    // into size-`2*half` DFTs.
    let mut half = 1u32;
    for _ in 0..log_n {
        // Primitive (2*half)-th root of unity: omega^(len / (2*half)).
        let w_stage = omega.pow(&[u64::from(len / (2 * half)), 0, 0, 0]);
        for start in (0..len).step_by((2 * half) as usize) {
            let mut twiddle = F::one();
            for offset in 0..half {
                let lo = (start + offset) as usize;
                let hi = lo + half as usize;
                let mut t = a[hi];
                t *= twiddle;
                a[hi] = a[lo] - t;
                a[lo] += t;
                twiddle *= w_stage;
            }
        }
        half *= 2;
    }
}
2.4 点值表示的多项式乘法运算
best_fft(&mut a, alpha, exp); // coefficient form -> point-value form
best_fft(&mut b, alpha, exp);
// Multiply pairwise: in point-value form, multiplying the two polynomials means
// multiplying their values at each evaluation point.
let num_cpus = num_cpus::get();
// Split the element-wise product across threads only when there are more elements than CPUs.
if a.len() > num_cpus {
    thread::scope(|scope| {
        // NOTE(review): calls num_cpus::get() again instead of reusing `num_cpus`; when
        // a.len() is not divisible by it, chunks_mut yields one extra, shorter chunk.
        let chunk = a.len() / num_cpus::get();
        for (a, b) in a.chunks_mut(chunk).zip(b.chunks(chunk)) {
            scope.spawn(move |_| {
                for (a, b) in a.iter_mut().zip(b.iter()) {
                    *a *= *b;
                }
            });
        }
    })
    .unwrap();
} else {
    for (a, b) in a.iter_mut().zip(b.iter()) {
        *a *= *b;
    }
}
2.5 傅里叶逆变换IFFT
IFFT的作用是将点值表示转换为系数表示:用$\alpha^{-1}$(`alpha_inv`)代替$\alpha$再做一次FFT,然后将每个结果乘以$m^{-1}$,即除以点数$m=2^{exp}$。
// Inverse FFT
// The same FFT run with alpha^{-1} instead of alpha, followed by scaling every entry
// by 1/m, maps point-value form back to coefficient form.
let alpha_inv = alpha.invert().unwrap();
best_fft(&mut a, alpha_inv, exp);
// Divide all elements by m = a.len()
let minv = F::from_u64(m as u64).invert().unwrap();
if a.len() > num_cpus {
    thread::scope(|scope| {
        let chunk = a.len() / num_cpus::get();
        for a in a.chunks_mut(chunk) {
            scope.spawn(move |_| {
                for a in a.iter_mut() {
                    *a *= minv;
                }
            });
        }
    })
    .unwrap();
} else {
    for a in a.iter_mut() {
        *a *= minv;
    }
}
// `a` was zero-padded to the power-of-two domain size m for the radix-2 FFT;
// keep only the coeffs_of_result coefficients of the product.
a.truncate(coeffs_of_result);
2.6 补充资料
sage: root=155978335310571138272812138773814534618935470879470300630834867870977
....: 67520449
sage: p=417793508166910149535221561915641187328610041455049381805950185075960478
....: 43329
sage: omega= power_mod(root,2^21,p)//21=32-exp, for exp=11.
sage: omega
35014335792849108923302692126549442116295992392289760687159465394416590439942
sage: hex(omega)
'4d696968d9c7e5b55e6a88fe57cbaa9e166872f777629c2cd200ba70d7cec606'
sage: new_omega=power_mod(omega,4,p)
sage: hex(new_omega)
'323286f8bcf390a8b2be7ef037eab34c127cf6fc0b1e0bc4866f1ba33bc1dc80'
//对应的new_omega为2^{exp-log_cpus}即2^9-th primitive root of unity。
sage: power_mod(new_omega, 2^9,p)
1
sage: R=2^256
sage: mod(R,p) //域内元素1的Montgomery表示:在Montgomery形式下元素x存储为x·R mod p,因此1存储为R mod p;montgomery_reduce会再除以R,得到R/R=1。
32233387603934165516526672625559670387547976374630687678267546992721033953278
sage: hex(3223338760393416551652667262555967038754797637463068767826754699272103
....: 3953278)
'4743736b947db12c8a7ab15117a98d9efc82c7cb9bfdb6facc000305fffffffe'
对用地
参考资料:
[1] 论文《Halo: Recursive Proof Composition without a Trusted Setup》
[2] https://electriccoin.co/zh/blog/halo-recursive-proof-composition-without-a-trusted-setup/
[3] https://github.com/ebfull/halo