刚开始学隐私计算,读到SIRNN,感觉真的好难好难,门槛比deep learning高好多,先尽量啃一啃(捂脸.jpg)。
1 文章及代码
Paper: SIRNN: A Math Library for Secure RNN Inference
Code: https://github.com/mpc-msri/EzPC.
2 主要贡献
- 为数学函数(指数、sigmoid、tanh、平方根倒数)提出了全新的密码学友好的新近似。
- 为不均匀(混合)的bitwidth提供2PC协议,实现高效的数学函数。
- SIRNN首次为RNN和CNN提供了安全推理库,在延迟、通信等方面达到SOTA,并拥有高数值准确率。
3 概览
3.1 Scale和bitwidth
2PC在整数上运算比在浮点数上运算更高效,在定点运算数中, ⌊ r 2 s ⌋ m o d 2 l \lfloor r2^s\rfloor \mod 2^l ⌊r2s⌋mod2l,其中 l l l就是bitwidth, s s s是scale。
3.2 对数学函数的近似
- 首先用lookup table (LUT)得到一个不错的初始化近似,然后用迭代算法提升这个近似。
- 更大的LUT近似结果更准确,但是通信开销线性增长。
- 对于指数和负数输入,分解输入x到更小的子串(digit decomposition)。
- 为了让迭代算法更高效,本文采用定点(fixed-point)算数及不均匀的混合bitwidth。
3.3 SIRNN协议
安全参数:
λ
=
128
\lambda=128
λ=128
基于4种构造块:
(1)Extension(扩展)
Z
2
m
→
Z
2
n
(
m
<
n
)
\mathbb{Z}_{2^m} \rightarrow \mathbb{Z}_{2^n} (m<n)
Z2m→Z2n(m<n)
GC需要的通信开销(重构和重共享)为:
λ
(
4
m
+
2
n
)
\lambda(4m+2n)
λ(4m+2n) bits,SIRNN需要的通信开销仅为:
λ
m
\lambda m
λm,大约比GC快6x。
(2)Truncation(截断)
常用于乘法之后减小规模,对于
l
l
l-bit截断了
s
s
s-bit有四种截断操作:
- 逻辑右移(保留位宽)
- 算数右移(保留位宽)
- 截断且减小(输出截断值 Z 2 l − s \mathbb Z_{2^{l-s}} Z2l−s)
- 除以 2 s 2^s 2s
目前最好的算数右移通信大约是:
λ
(
l
+
s
)
\lambda(l+s)
λ(l+s),本文提出的逻辑/算数右移协议大约是
λ
l
\lambda l
λl,大多数数学函数都只需要截断且减小去减小scale和bitwidth,SIRNN只需要
λ
(
s
+
1
)
\lambda (s+1)
λ(s+1)通信。
(3)Multiplication(乘法)
m
m
m-bit整数和
n
n
n-bit整数相乘得到
l
=
(
m
+
n
)
l=(m+n)
l=(m+n)-bit输出,
l
l
l的选择保证了没有溢出。
(4)Digit Decomposition(数位分解)
将
l
l
l-bit的值分解为
c
=
l
/
d
c=l/d
c=l/d个
d
d
d-bits,可以用GC实现,通信量为
λ
(
6
l
−
2
c
−
2
)
\lambda (6l-2c-2)
λ(6l−2c−2) bits。本文进一步优化,通信量为
λ
(
c
−
1
)
(
d
+
2
)
\lambda (c-1)(d+2)
λ(c−1)(d+2) bits,大约比GC低5x。
4 前提知识
4.1 ULP误差(units in last place)
ULP是真实数据和函数输出值之间的可表示数值的数量。
4.2 威胁模型
- 两方安全计算(2PC)
- 静态的半诚实攻击:遵循协议,但是会学习额外信息
4.3 符号表示
符号 | 意义 |
---|---|
x ∈ Z 2 l x\in \mathbb Z_{2^l} x∈Z2l | power-of-2 rings, x x x的环为 Z 2 l \mathbb Z_{2^l} Z2l,即以 2 l 2^l 2l为模 |
B B B | ring Z 2 \mathbb Z_2 Z2,即以2为模 |
λ \lambda λ | 计算安全系数 |
⊕ \oplus ⊕ | 异或门 |
ζ l , ζ l , m ( m > l ) \zeta_l, \zeta_{l,m} (m>l) ζl,ζl,m(m>l) | 无损lifting操作,映射 Z L → Z \mathbb Z_L\rightarrow \mathbb Z ZL→Z,映射 Z L → Z M \mathbb Z_L\rightarrow \mathbb Z_M ZL→ZM |
L , M , N L,M,N L,M,N | 2 l , 2 m , 2 n 2^l, 2^m, 2^n 2l,2m,2n |
[ k ] [k] [k] | 0 , 1 , . . , k − 1 {0, 1, .., k-1} 0,1,..,k−1 |
1 { b } 1\{b\} 1{b} | b = t r u e b=true b=true时为1,反之为0 |
i n t ( x ) int(x) int(x)和 u i n t ( x ) uint(x) uint(x) | 对于 x ∈ Z l x\in \mathbb Z^l x∈Zl,分别代表有符号和无符号值,int(x)=uint(x)−MSB(x)L |
MSB(x) | MSB(x) = 1 { x ≥ 2 l − 1 } =1\{x\geq 2^{l-1}\} =1{x≥2l−1},表示最有效高位 |
F M i l l l ( x , y ) F_{Mill}^l(x, y) FMilll(x,y) | F M i l l l ( x , y ) = ⟨ z ⟩ B = 1 { x < y } F_{Mill}^l(x, y)=\langle z\rangle^B=1\{x<y\} FMilll(x,y)=⟨z⟩B=1{x<y} |
F w r a p l F_{wrap}^l Fwrapl | F w r a p l = F M i l l l ( L − 1 − x , y ) : w = w r a p ( x , y , L ) = 1 { x + y ≥ L } F_{wrap}^l=F_{Mill}^l(L-1-x, y): w=wrap(x, y, L)=1\{x+y\geq L\} Fwrapl=FMilll(L−1−x,y):w=wrap(x,y,L)=1{x+y≥L} |
e e e | e = 1 { ( x + y m o d L ) = L − 1 } e=1\{(x+y \mod L)=L-1\} e=1{(x+ymodL)=L−1},判断是否全是1 |
F w r a p & a l l 1 s l F_{wrap\&all1s}^l Fwrap&all1sl | F w r a p & a l l 1 s l ( x , y ) = ( ⟨ w ⟩ B ∣ ∣ ⟨ e ⟩ B ) F_{wrap\&all1s}^l(x,y)=(\langle w\rangle^B||\langle e\rangle^B) Fwrap&all1sl(x,y)=(⟨w⟩B∣∣⟨e⟩B),至多一项是1 |
∗ m *_m ∗m | x ∗ m y = x y m o d M x*_m y=xy\mod M x∗my=xymodM,从 Z × Z → Z M \mathbb Z \times \mathbb Z \rightarrow \mathbb Z_M Z×Z→ZM |
l l l | bitwidth |
s s s | scale |
l − s l-s l−s | 整数部分的bitwidth |
F i x ( x , l , s ) Fix(x, l, s) Fix(x,l,s) | F i x ( x , l , s ) = x 2 s m o d L Fix(x, l, s)=x2^s \mod L Fix(x,l,s)=x2smodL,从实数转到定点数表示 |
u r t ( l , s ) ( a ) urt_{(l,s)}(a) urt(l,s)(a) | 对于无符号数, u r t ( l , s ) ( a ) = u i n t ( a ) / 2 s urt_{(l,s)}(a)=uint(a)/2^s urt(l,s)(a)=uint(a)/2s,从定点数转到实数表示 |
s r t ( l , s ) ( a ) srt_{(l,s)}(a) srt(l,s)(a) | 对于有符号数, s r t ( l , s ) ( a ) = i n t ( a ) / 2 s srt_{(l,s)}(a)=int(a)/2^s srt(l,s)(a)=int(a)/2s,从定点数转到实数表示 |
> > L , > > A >>_L, >>_A >>L,>>A | 逻辑右移和算术右移 |
4.4 密码学基础
- 秘密共享(SS)
2-out-of-2加性秘密共享: x = ⟨ x ⟩ 0 l + ⟨ x ⟩ 1 l m o d L x=\langle x\rangle_0^l+\langle x\rangle_1^l \mod L x=⟨x⟩0l+⟨x⟩1lmodL。 - 不经意传输(OT)
1-out-of-k OT,用OT Extension (OTE)实现,并用了Correlated OT (COT)。
4.5 2PC基本函数
- 百万富翁/wrap
F M i l l l = 1 { x < y } F_{Mill}^l=1\{x<y\} FMilll=1{x<y},CrypTFlow2中通信量低于 λ l + 14 l \lambda l+14l λl+14l bits和 log l \log l logl rounds。
F w r a p l = F M i l l l ( L − 1 − x , y ) : w = w r a p ( x , y , L ) = 1 { x + y ≥ L } F_{wrap}^l=F_{Mill}^l(L-1-x, y): w=wrap(x, y, L)=1\{x+y\geq L\} Fwrapl=FMilll(L−1−x,y):w=wrap(x,y,L)=1{x+y≥L} - AND
输入 ⟨ x ⟩ B , ⟨ y ⟩ B \langle x\rangle^B, \langle y\rangle^B ⟨x⟩B,⟨y⟩B,输出 ⟨ x ∧ y ⟩ B \langle x \land y\rangle^B ⟨x∧y⟩B,用Beaver bit-triples实现,CrypTFlow2中通信量为 λ + 20 \lambda+20 λ+20。 - Boolean to Arithmetic (B2A)
输入boolean share,输出相同值的算术share,采用COT协议实现,通信量为 λ + l \lambda+l λ+l bits。 - Multiplexer (MUX)
⟨ x ⟩ B \langle x\rangle^B ⟨x⟩B和 ⟨ y ⟩ l \langle y\rangle^l ⟨y⟩l作为输入,输出 ⟨ z ⟩ l \langle z\rangle^l ⟨z⟩l,如果 x = 1 x=1 x=1,则 z = y z=y z=y,反之同理。本文提出的协议将通信量从 2 ( λ + 2 l ) 2(\lambda+2l) 2(λ+2l)(CrypTFlow2)降到 2 ( λ + l ) 2(\lambda+l) 2(λ+l)。 - Lookup Table (LUT)
对于表 T T T, M M M个入口,每个 n n n-bits,输入 ⟨ x ⟩ m \langle x\rangle^m ⟨x⟩m, ⟨ z ⟩ n \langle z\rangle^n ⟨z⟩n,满足 z = T [ x ] z=T[x] z=T[x]。可以用1-out-of-m OT实现,通信量为 2 λ + M n 2\lambda +Mn 2λ+Mn bits。这是个查表的操作,输入和输出的位数是不同的。
5 构建块协议
5.1 零扩展和有符号扩展
对于
m
m
m-bit的数
x
∈
Z
M
x\in \mathbb Z_M
x∈ZM,将其转换为
n
n
n-bit的数(
n
>
m
n>m
n>m),这个过程就称为扩展(extension)。零扩展和有符号扩展分别用于扩展无符号数和有符号数的位宽。
零扩展(Zero Extension)
P
0
P_0
P0和
P
1
P_1
P1两方输入
⟨
x
⟩
m
\langle x\rangle^m
⟨x⟩m,扩展输出
⟨
y
⟩
n
\langle y\rangle^n
⟨y⟩n,要求满足
u
n
i
t
(
x
)
=
u
i
n
t
(
y
)
unit(x)=uint(y)
unit(x)=uint(y)。对于
x
m
∈
Z
M
x^m\in \mathbb Z_M
xm∈ZM,可以得到 【问:这个等式在后面广泛使用,没太理解怎么来的】【答:其实
−
w
M
-wM
−wM就是实现的
m
o
d
M
\mod M
modM计算过程,防止求和在
Z
2
m
\mathbb Z_{2^m}
Z2m环上溢出】:
x
m
=
⟨
x
⟩
0
m
+
⟨
x
⟩
1
m
−
w
M
x^m = \langle x \rangle_0^m+\langle x \rangle_1^m-wM
xm=⟨x⟩0m+⟨x⟩1m−wM
其中,
w
=
w
r
a
p
(
⟨
x
⟩
0
m
,
⟨
x
⟩
1
m
,
M
)
w=wrap(\langle x \rangle_0^m, \langle x \rangle_1^m, M)
w=wrap(⟨x⟩0m,⟨x⟩1m,M),这是个boolean share,需要转换为算术share。这里考虑在
n
−
m
n-m
n−m环上转换,原因就是下面的模约减步骤会使通信量大大降低。
F
B
2
A
n
−
m
(
⟨
w
⟩
B
)
=
⟨
w
⟩
n
−
m
∈
Z
2
n
−
m
F_{B2A}^{n-m}(\langle w\rangle^B)=\langle w\rangle^{n-m}\in \mathbb Z_{2^{n-m}}
FB2An−m(⟨w⟩B)=⟨w⟩n−m∈Z2n−m
w = ⟨ w ⟩ 0 n − m + ⟨ w ⟩ 1 n − m − w r a p ( ⟨ w ⟩ 0 n − m , ⟨ w ⟩ 1 n − m , Z 2 n − m ) 2 n − m w = \langle w\rangle_0^{n-m} + \langle w\rangle_1^{n-m}-wrap(\langle w\rangle_0^{n-m}, \langle w\rangle_1^{n-m}, \mathbb Z_{2^{n-m}})2^{n-m} w=⟨w⟩0n−m+⟨w⟩1n−m−wrap(⟨w⟩0n−m,⟨w⟩1n−m,Z2n−m)2n−m
M ∗ n w = M ∗ n ( ⟨ w ⟩ 0 n − m + ⟨ w ⟩ 1 n − m − w r a p ( ⟨ w ⟩ 0 n − m , ⟨ w ⟩ 1 n − m , Z 2 n − m ) 2 n − m ) M_{*n}w = M_{*n}(\langle w\rangle_0^{n-m} + \langle w\rangle_1^{n-m} - wrap(\langle w\rangle_0^{n-m}, \langle w\rangle_1^{n-m}, \mathbb Z_{2^{n-m}})2^{n-m}) M∗nw=M∗n(⟨w⟩0n−m+⟨w⟩1n−m−wrap(⟨w⟩0n−m,⟨w⟩1n−m,Z2n−m)2n−m)
其中,
M
∗
n
w
r
a
p
(
⋅
)
2
n
−
m
=
M
w
r
a
p
(
⋅
)
2
n
−
m
m
o
d
N
=
w
r
a
p
(
⋅
)
2
n
m
o
d
N
=
0
M_{*n}wrap(\cdot)2^{n-m}=Mwrap(\cdot)2^{n-m} \mod N=wrap(\cdot)2^{n} \mod N=0
M∗nwrap(⋅)2n−m=Mwrap(⋅)2n−mmodN=wrap(⋅)2nmodN=0(这一步称作“模约减”,modulo-reduce),所以上式子转换为:
M
∗
n
w
=
M
∗
n
(
⟨
w
⟩
0
n
−
m
+
⟨
w
⟩
1
n
−
m
)
M_{*n}w = M_{*n}(\langle w\rangle_0^{n-m} + \langle w\rangle_1^{n-m})
M∗nw=M∗n(⟨w⟩0n−m+⟨w⟩1n−m)
于是:
y
=
∑
b
=
0
1
(
⟨
x
⟩
b
m
−
M
⟨
w
⟩
b
n
−
m
)
m
o
d
N
y = \sum_{b=0}^1(\langle x\rangle_b^m-M\langle w\rangle_b^{n-m}) \mod N
y=b=0∑1(⟨x⟩bm−M⟨w⟩bn−m)modN
这里是在
P
0
P_0
P0和
P
1
P_1
P1上分别计算,然后求和取模,得到扩展后的结果。其中,
x
m
o
d
N
=
y
x \mod N=y
xmodN=y。
算法如下:
需要
log
(
m
+
2
)
\log(m+2)
log(m+2) rounds和少于
λ
(
m
+
1
)
+
13
m
+
n
\lambda(m+1)+13m+n
λ(m+1)+13m+n bits的通信量。作为对比,用GC实现零扩展和有符号扩展需要
λ
(
4
m
+
2
n
−
4
)
\lambda(4m+2n-4)
λ(4m+2n−4) bits的通信量,大约是SIRNN的6倍。
有符号扩展(Signed Extension)
有符号扩展可以基于以下等式,通过转换无符号扩展得到,在环
Z
\mathbb Z
Z上:
i
n
t
(
x
)
=
x
′
−
2
m
−
1
,
x
′
=
x
+
2
m
−
1
m
o
d
M
int(x)=x'-2^{m-1}, x'=x+2^{m-1} \mod M
int(x)=x′−2m−1,x′=x+2m−1modM
证明如下:
于是:
S
E
x
t
(
x
,
m
,
n
)
=
Z
E
x
t
(
x
,
m
,
n
)
−
2
m
−
1
SExt(x, m, n)=ZExt(x, m, n)-2^{m-1}
SExt(x,m,n)=ZExt(x,m,n)−2m−1
相比零扩展,没有额外的通信开销。
5.2 截断
首先,规定
>
>
L
,
>
>
A
>>_L,>>_A
>>L,>>A分别表示逻辑右移和算术右移,它们的输入和输出都是在
Z
L
\mathbb Z_L
ZL环上。
T
R
(
x
,
s
)
TR(x, s)
TR(x,s)表示截断且减小(truncate & reduce),将
x
∈
Z
L
x\in \mathbb Z_L
x∈ZL截断且减小
s
s
s-bits,最终得到的
x
x
x在更小的
Z
2
l
−
s
\mathbb Z_{2^{l-s}}
Z2l−s环上。
逻辑右移
Toy example:
x
=
101001
x=101001
x=101001逻辑右移3位,则
x
′
=
000101
x'=000101
x′=000101(右侧截掉,左侧补0)。
对于
x
∈
Z
L
x\in \mathbb Z_L
x∈ZL,则
x
=
⟨
x
⟩
0
l
+
⟨
x
⟩
1
l
m
o
d
L
x=\langle x\rangle_0^l+\langle x\rangle_1^l \mod L
x=⟨x⟩0l+⟨x⟩1lmodL,记
⟨
x
⟩
b
l
=
u
b
∣
∣
v
b
\langle x\rangle_b^l=u_b||v_b
⟨x⟩bl=ub∣∣vb(
u
b
u_b
ub是高位,
v
b
v_b
vb是低位),其中
u
b
∈
{
0
,
1
}
l
−
s
,
v
b
∈
{
0
,
1
}
s
u_b\in\{0, 1\}^{l-s}, v_b\in\{0, 1\}^{s}
ub∈{0,1}l−s,vb∈{0,1}s。如下图:
根据前面提到的公式:
x
m
=
⟨
x
⟩
0
m
+
⟨
x
⟩
1
m
−
w
M
x^m = \langle x \rangle_0^m+\langle x \rangle_1^m-wM
xm=⟨x⟩0m+⟨x⟩1m−wM
可以得到:
x
>
>
L
s
=
u
0
+
u
1
−
2
l
−
s
w
r
a
p
(
⟨
x
⟩
0
l
,
⟨
x
⟩
1
l
,
L
)
+
w
r
a
p
(
v
0
,
v
1
,
2
s
)
x>>_Ls=u_0+u_1-2^{l-s} wrap (\langle x\rangle_0^l, \langle x\rangle_1^l, L) + wrap(v_0, v_1, 2^s)
x>>Ls=u0+u1−2l−swrap(⟨x⟩0l,⟨x⟩1l,L)+wrap(v0,v1,2s)
上式中,
w
r
a
p
(
v
0
,
v
1
,
2
s
)
wrap(v_0, v_1, 2^s)
wrap(v0,v1,2s)这一项是考虑了进位。我们知道,加性秘密共享时,
v
v
v部分可能会存在1位进位的情况,所以
w
r
a
p
(
v
0
,
v
1
,
2
s
)
wrap(v_0, v_1, 2^s)
wrap(v0,v1,2s)就是判断
v
0
+
v
1
v_0+v_1
v0+v1是否大于
2
s
2^s
2s,如果是,则会进1,如果不是,则为0。
常规做法是计算两个
w
r
a
p
(
⋅
)
wrap(\cdot)
wrap(⋅)值即可,但是SIRNN提出了一种优化,避开直接计算位宽是
l
l
l的那一项。文章中的Lemma 1即是这个引理:
通信开销低于
λ
(
l
+
3
)
+
15
+
s
+
20
\lambda(l+3)+15+s+20
λ(l+3)+15+s+20,并需要
log
l
+
3
\log l+3
logl+3 rounds。
原文证明如下:
算法如下:
算术右移
对于无符号数,直接采用逻辑右移,对于有符号数,则需要采用算术右移。从前面零扩展到有符号扩展可以知道:
i
n
t
(
x
)
=
x
′
−
2
l
−
1
,
x
′
=
x
+
2
l
−
1
m
o
d
L
int(x)=x'-2^{l-1}, x'=x+2^{l-1} \mod L
int(x)=x′−2l−1,x′=x+2l−1modL,于是:
x
>
>
A
s
=
x
>
>
L
s
−
2
l
−
s
−
1
x>>_As = x>>_Ls-2^{l-s-1}
x>>As=x>>Ls−2l−s−1
截断且减小
Toy example:
x
=
101001
x=101001
x=101001截断且减小3位,则
x
′
=
101
x'=101
x′=101。
因为
2
l
−
s
∗
l
w
m
o
d
2
l
−
s
=
0
2^{l-s}{*_l} w \mod 2^{l-s}=0
2l−s∗lwmod2l−s=0(模约减),所以:
⟨
T
R
(
x
,
s
)
⟩
l
−
s
=
u
0
+
u
1
+
w
r
a
p
(
v
0
,
v
1
,
2
s
)
\langle TR(x, s)\rangle^{l-s}=u_0+u_1+wrap(v_0, v_1, 2^s)
⟨TR(x,s)⟩l−s=u0+u1+wrap(v0,v1,2s)
除以power-of-2
z
<
0
,
z
=
⌈
i
n
t
(
x
)
/
2
s
⌉
m
o
d
L
;
z
≥
0
,
z
=
⌊
i
n
t
(
x
)
/
2
s
⌋
m
o
d
L
z<0, z=\lceil int(x)/2^s\rceil \mod L; z\geq0, z=\lfloor int(x)/2^s\rfloor \mod L
z<0,z=⌈int(x)/2s⌉modL;z≥0,z=⌊int(x)/2s⌋modL
实际上
i
n
t
(
x
)
/
2
s
m
o
d
L
int(x)/2^s \mod L
int(x)/2smodL就是做
>
>
A
>>_A
>>A,取整括号即是将值往0靠近。令
m
x
=
1
{
x
≥
2
l
−
1
}
m_x=1\{x\geq 2^{l-1}\}
mx=1{x≥2l−1}判断
x
x
x的正负性,
c
=
1
{
x
m
o
d
2
s
=
0
}
c=1\{x\mod 2^s=0\}
c=1{xmod2s=0}
m
x
=
1
m_x=1
mx=1,则
z
<
0
,
⌈
z
⌉
z<0, \lceil z\rceil
z<0,⌈z⌉;反之,
⌊
z
⌋
\lfloor z\rfloor
⌊z⌋。所以有:
D
i
v
P
o
w
2
(
x
,
s
)
=
(
x
>
>
A
s
)
+
m
x
∧
c
DivPow2(x, s)=(x>>_As)+m_x\land c
DivPow2(x,s)=(x>>As)+mx∧c
5.3 混合位宽乘法
以前做乘法通常是用Beaver Triplet三元组实现,SIRNN中不能用了,因为加法和乘法的数bitwidth不一致。
无符号乘法
输入
⟨
x
⟩
m
,
⟨
y
⟩
n
\langle x\rangle^m, \langle y\rangle^n
⟨x⟩m,⟨y⟩n,输出
⟨
z
⟩
l
,
z
=
x
∗
l
y
,
l
=
n
+
m
\langle z\rangle^l, z=x*_l y, l=n+m
⟨z⟩l,z=x∗ly,l=n+m。
对于
x
,
y
x,y
x,y,在
Z
\mathbb Z
Z上有:
u
i
n
t
(
x
)
⋅
u
i
n
t
(
y
)
=
(
x
0
+
x
1
−
2
m
w
x
)
⋅
(
y
0
+
y
1
−
2
n
w
y
)
=
x
0
y
0
+
x
0
y
1
+
x
1
y
0
+
x
1
y
1
−
2
m
w
x
y
−
2
n
w
y
x
+
2
l
w
x
w
y
uint(x)\cdot uint(y)=(x_0+x_1-2^mw_x)\cdot(y_0+y_1-2^nw_y)\\=x_0y_0+x_0y_1+x_1y_0+x_1y_1-2^mw_xy-2^nw_yx+2^lw_xw_y
uint(x)⋅uint(y)=(x0+x1−2mwx)⋅(y0+y1−2nwy)=x0y0+x0y1+x1y0+x1y1−2mwxy−2nwyx+2lwxwy
观察上式,
x
0
y
0
,
x
1
y
1
x_0y_0,x_1y_1
x0y0,x1y1都是可以本地计算的【本地计算为什么不管位宽是否一致?】,
2
l
w
x
w
y
2^lw_xw_y
2lwxwy可以在
m
o
d
L
\mod L
modL时被消掉(模约减),
w
x
y
,
x
y
x
w_xy, x_yx
wxy,xyx是boolean share和算术share的计算,本质上是MUX,可用直接用OT实现。最难的一项是交叉项
x
0
y
1
,
x
1
y
0
x_0y_1, x_1y_0
x0y1,x1y0,SIRNN采用COT实现。
巧妙的一点在于:选择比特位短的一方作为receiver,比特位长的一方作为sender,这样在做OT的取数时,round数就会更少。
交叉项算法如下:
无符号乘法算法如下:
SIRNN利用1-out-of-2的COT来实现这个过程,将短的数按位拆解,每一位非0即1,然后做二选一的COT,每一位计算完成后,在本地累加起来。
通信开销大约是:
λ
(
3
μ
+
v
)
+
μ
(
μ
+
2
v
)
+
16
(
m
+
n
)
\lambda(3\mu + v) + \mu(\mu + 2v) + 16(m + n)
λ(3μ+v)+μ(μ+2v)+16(m+n),其中
μ
=
min
(
m
,
n
)
,
ν
=
max
(
m
,
n
)
\mu = \min(m, n), ν = \max(m, n)
μ=min(m,n),ν=max(m,n)。普通的扩展位数然后相乘的开销是:
3
λ
(
μ
+
v
)
+
(
m
+
n
)
2
+
15
(
m
+
n
)
3\lambda(\mu+v)+(m+n)^2+15(m + n)
3λ(μ+v)+(m+n)2+15(m+n),大约是SIRNN的1.5x。
有符号乘法
布尔分享转换为算术分享:
⟨
x
⟩
A
=
⟨
x
⟩
0
B
+
⟨
x
⟩
1
B
−
2
⟨
x
⟩
0
B
⟨
x
⟩
1
B
\langle x\rangle^A=\langle x\rangle_0^B+\langle x\rangle_1^B-2\langle x\rangle_0^B\langle x\rangle_1^B
⟨x⟩A=⟨x⟩0B+⟨x⟩1B−2⟨x⟩0B⟨x⟩1B
基于前面无符号数和有符号数的关系,可以得到:无符号数
x
′
=
x
+
2
m
−
1
m
o
d
M
,
y
′
=
y
+
2
n
−
1
m
o
d
N
x'=x+2^{m-1}\mod M, y'=y+2^{n-1}\mod N
x′=x+2m−1modM,y′=y+2n−1modN。由秘密共享,
x
′
=
x
0
′
+
x
1
′
m
o
d
M
,
y
′
=
y
0
′
+
y
1
′
m
o
d
N
x'=x_0'+x_1' \mod M, y'=y_0'+y_1' \mod N
x′=x0′+x1′modM,y′=y0′+y1′modN。有符号数
i
n
t
(
x
)
=
x
′
−
2
m
−
1
,
i
n
t
(
y
)
=
y
′
−
2
n
−
1
int(x)=x'-2^{m-1}, int(y)=y'-2^{n-1}
int(x)=x′−2m−1,int(y)=y′−2n−1。因此,在
Z
\mathbb Z
Z环上:
x
′
y
′
x'y'
x′y′是无符号数的乘法,可以用algorithm 3计算,
2
m
−
1
y
b
′
,
2
n
−
1
x
b
′
2^{m-1}y_b', 2^{n-1}x_b'
2m−1yb′,2n−1xb′也都可以在本地计算出来。难点是wrap项应该如何计算。
2
m
+
n
−
1
w
x
′
=
2
l
−
1
w
x
′
=
2
l
−
1
(
⟨
w
x
′
⟩
0
B
+
⟨
w
x
′
⟩
1
B
−
2
⟨
w
x
′
⟩
0
B
⟨
w
x
′
⟩
1
B
)
2^{m+n-1}w_{x'}=2^{l-1}w_{x'}=2^{l-1}(\langle w_{x'}\rangle_0^B+\langle w_{x'}\rangle_1^B-2\langle w_{x'}\rangle_0^B\langle w_{x'}\rangle_1^B)
2m+n−1wx′=2l−1wx′=2l−1(⟨wx′⟩0B+⟨wx′⟩1B−2⟨wx′⟩0B⟨wx′⟩1B)
其中,
2
⟨
w
x
′
⟩
0
B
⟨
w
x
′
⟩
1
B
2\langle w_{x'}\rangle_0^B\langle w_{x'}\rangle_1^B
2⟨wx′⟩0B⟨wx′⟩1B与
2
l
−
1
2^{l-1}
2l−1相乘再
m
o
d
L
\mod L
modL后会被消除掉,所以无需计算。因此,上式变为:
2
m
+
n
−
1
w
x
′
=
2
l
−
1
w
x
′
=
2
l
−
1
(
⟨
w
x
′
⟩
0
B
+
⟨
w
x
′
⟩
1
B
)
2^{m+n-1}w_{x'}=2^{l-1}w_{x'}=2^{l-1}(\langle w_{x'}\rangle_0^B+\langle w_{x'}\rangle_1^B)
2m+n−1wx′=2l−1wx′=2l−1(⟨wx′⟩0B+⟨wx′⟩1B)
有符号的乘法相比无符号的乘法,也没有额外的开销。
矩阵乘法和卷积
矩阵乘法和卷积是很常见的(实际上可以展开为普通乘法做elment-wise乘和加),两个矩阵
A
∈
Z
M
d
1
×
d
2
,
A
∈
Z
N
d
2
×
d
3
A\in \mathbb Z_M^{d1\times d2}, A\in \mathbb Z_N^{d2\times d3}
A∈ZMd1×d2,A∈ZNd2×d3,输出矩阵乘法结果
A
∈
Z
L
d
1
×
d
3
A\in \mathbb Z_L^{d1\times d3}
A∈ZLd1×d3,其中
l
=
m
+
n
l=m+n
l=m+n。做矩阵乘法需要
d
2
d_2
d2次乘以及
d
2
−
1
d_2-1
d2−1次加。
这个时候可能出现的问题是:加法导致溢出。一种解决方式是将element-wise乘后的结果扩展
e
=
⌈
log
d
2
⌉
e=\lceil \log d_2\rceil
e=⌈logd2⌉-bits后,再做加法。但是,这样扩展开销很大,需要扩展
d
1
d
2
d
3
d_1d_2d_3
d1d2d3次。
于是本文这样做:考虑到前面算交叉项(CrossTerm)时,通信round数取决于较小的bitwidth,所以本文将bitwidth较大的一项拿去扩展
e
e
e-bits,在不增加开销的情况下,扩大了环。
通信开销大致为
λ
(
3
d
1
d
2
(
m
+
2
)
+
d
2
d
3
(
n
+
2
)
)
+
d
1
d
2
d
3
(
(
2
m
+
4
)
(
n
+
e
)
+
m
2
+
5
m
)
\lambda(3d_1d_2(m+2)+d_2d_3(n+2))+d_1d_2d_3((2m+4)(n+e)+m^2+5m)
λ(3d1d2(m+2)+d2d3(n+2))+d1d2d3((2m+4)(n+e)+m2+5m) bits。
算法如下:
乘且截断
首先调用有符号乘法,然后截断。输入
⟨
x
⟩
m
,
⟨
y
⟩
n
\langle x\rangle^m, \langle y\rangle^n
⟨x⟩m,⟨y⟩n,输出
⟨
z
′
⟩
l
−
s
\langle z'\rangle^{l-s}
⟨z′⟩l−s。
z
=
i
n
t
(
x
)
∗
l
i
n
t
(
y
)
,
z
′
=
T
R
(
z
,
s
)
z=int(x)*_l int(y), z'=TR(z, s)
z=int(x)∗lint(y),z′=TR(z,s)。其中
l
=
m
+
n
l=m+n
l=m+n。
5.4 数值分解和MSNZB (Most Significant Non-Zero Bit)
数值分解
将
l
l
l-bit的数分解为
c
c
c个长度为
d
=
l
/
c
d=l/c
d=l/c的子串或数值,使得
x
=
z
c
−
1
∣
∣
.
.
.
∣
∣
z
0
x=z_{c-1}||...||z_0
x=zc−1∣∣...∣∣z0。
算法如下:
MSNZB
返回最高非零比特的索引:比如
x
=
001010
x=001010
x=001010返回的就是3。
算法如下:
5.5 MSB to Wrap Optimization
本文大量依赖于 w = w r a p ( ⟨ x ⟩ 0 l , ⟨ x ⟩ 1 l , L ) w=wrap(\langle x\rangle_0^l, \langle x\rangle_1^l, L) w=wrap(⟨x⟩0l,⟨x⟩1l,L),一些情况下,我们能得到 m x = M S B ( x ) m_x=MSB(x) mx=MSB(x)或 ⟨ m x ⟩ B \langle m_x\rangle^B ⟨mx⟩B,于是 w = ( ( 1 ⊕ m x ) ∧ ( m 0 ⊕ m 1 ) ⊕ ( m 0 ∧ m 1 ) ) w=((1\oplus m_x)\land (m_0\oplus m_1)\oplus(m_0\land m_1)) w=((1⊕mx)∧(m0⊕m1)⊕(m0∧m1)),其中 m b = M S B ( ⟨ x ⟩ B l ) m_b=MSB(\langle x\rangle_B^l) mb=MSB(⟨x⟩Bl)。当 m x m_x mx是秘密分享时,使用 ( 4 1 ) \binom{4}{1} (14)-OT;当 m x m_x mx是明文时,使用 ( 2 1 ) \binom{2}{1} (12)-OT。
6 构建数学库
6.1 指数
求
r
E
x
p
(
z
)
=
e
−
z
,
z
∈
R
+
rExp(z)=e^{-z}, z\in \mathbb R^+
rExp(z)=e−z,z∈R+的值,首先将输入
x
x
x分成
k
k
k段,然后每段在LUT (Look Up Table)进行查表,将得到的结果相乘。
算法如下:
6.2 Sigmoid和Tanh
s
i
g
m
o
i
d
(
z
)
=
1
1
+
e
−
z
sigmoid(z)=\frac{1}{1+e^{-z}}
sigmoid(z)=1+e−z1,可以表示如下:
其中,
h
(
z
)
=
1
1
+
r
E
x
p
(
z
)
h(z)=\frac{1}{1+rExp(z)}
h(z)=1+rExp(z)1的计算是先求
r
E
x
p
rExp
rExp然后求倒数:
倒数则是采用Goldschmidt’s迭代近似算法实现,算法如下:
Tanh和sigmoid存在数学上的关系:
T
a
n
h
(
z
)
=
e
z
−
e
−
z
e
z
+
e
−
z
=
2
s
i
g
m
o
i
d
(
2
z
)
−
1
Tanh(z)=\frac{e^z-e^{-z}}{e^z+e^{-z}}=2sigmoid(2z)-1
Tanh(z)=ez+e−zez−e−z=2sigmoid(2z)−1,所以可以用如上方式实现。
6.3 平方根倒数
计算
r
s
q
r
t
(
x
)
=
1
x
rsqrt(x)=\frac{1}{\sqrt x}
rsqrt(x)=x1,为了防止分母为0,首先加上一个很小的
ϵ
\epsilon
ϵ有
r
s
q
r
t
(
x
)
=
1
x
+
ϵ
rsqrt(x)=\frac{1}{\sqrt {x+\epsilon}}
rsqrt(x)=x+ϵ1。
首先,进行初始化,然后用Goldschmidt法进行迭代,
算法如下:
通信开销和轮次汇总
参考资料:
SIRNN: A Math Library for Secure RNN Inference
2021-10-02-SIRNN