【隐私计算】SIRNN: A Math Library for Secure RNN Inference

刚开始学隐私计算,读到SIRNN,感觉真的好难好难,门槛比deep learning高好多,先尽量啃一啃(捂脸.jpg)。


1 文章及代码

Paper: SIRNN: A Math Library for Secure RNN Inference
Code: https://github.com/mpc-msri/EzPC.

2 主要贡献

  • 为数学函数(指数、sigmoid、tanh、平方根倒数)提出了全新的密码学友好的新近似。
  • 为不均匀(混合)的bitwidth提供2PC协议,实现高效的数学函数。
  • SIRNN首次为RNN和CNN提供了安全推理库,在延迟、通信等方面达到SOTA,并拥有高数值准确率。

3 概览

3.1 Scale和bitwidth

2PC在整数上运算比在浮点数上运算更高效,在定点运算数中, ⌊ r 2 s ⌋ m o d    2 l \lfloor r2^s\rfloor \mod 2^l r2smod2l,其中 l l l就是bitwidth, s s s是scale。

3.2 对数学函数的近似

  • 首先用lookup table (LUT)得到一个不错的初始化近似,然后用迭代算法提升这个近似。
  • 更大的LUT近似结果更准确,但是通信开销线性增长。
  • 对于指数和负数输入,分解输入x到更小的子串(digit decomposition)。
  • 为了让迭代算法更高效,本文采用定点(fixed-point)算数及不均匀的混合bitwidth。

3.3 SIRNN协议

安全参数: λ = 128 \lambda=128 λ=128
基于4种构造块:
(1)Extension(扩展)
Z 2 m → Z 2 n ( m < n ) \mathbb{Z}_{2^m} \rightarrow \mathbb{Z}_{2^n} (m<n) Z2mZ2n(m<n)
GC需要的通信开销(重构和重共享)为: λ ( 4 m + 2 n ) \lambda(4m+2n) λ(4m+2n) bits,SIRNN需要的通信开销仅为: λ m \lambda m λm,大约比GC快6x。
(2)Truncation(截断)
常用于乘法之后减小规模,对于 l l l-bit截断了 s s s-bit有四种截断操作:

  • 逻辑右移(保留位宽)
  • 算数右移(保留位宽)
  • 截断且减小(输出截断值 Z 2 l − s \mathbb Z_{2^{l-s}} Z2ls
  • 除以 2 s 2^s 2s

目前最好的算数右移通信大约是: λ ( l + s ) \lambda(l+s) λ(l+s),本文提出的逻辑/算数右移协议大约是 λ l \lambda l λl,大多数数学函数都只需要截断且减小去减小scale和bitwidth,SIRNN只需要 λ ( s + 1 ) \lambda (s+1) λ(s+1)通信。
(3)Multiplication(乘法)
m m m-bit整数和 n n n-bit整数相乘得到 l = ( m + n ) l=(m+n) l=(m+n)-bit输出, l l l的选择保证了没有溢出。
(4)Digit Decomposition(数位分解)
l l l-bit的值分解为 c = l / d c=l/d c=l/d d d d-bits,可以用GC实现,通信量为 λ ( 6 l − 2 c − 2 ) \lambda (6l-2c-2) λ(6l2c2) bits。本文进一步优化,通信量为 λ ( c − 1 ) ( d + 2 ) \lambda (c-1)(d+2) λ(c1)(d+2) bits,大约比GC低5x。

4 前提知识

4.1 ULP误差(units in last place)

ULP是真实数据和函数输出值之间的可表示数值的数量。

4.2 威胁模型

  • 两方安全计算(2PC)
  • 静态的半诚实攻击:遵循协议,但是会学习额外信息

4.3 符号表示

符号意义
x ∈ Z 2 l x\in \mathbb Z_{2^l} xZ2lpower-of-2 rings, x x x的环为 Z 2 l \mathbb Z_{2^l} Z2l,即以 2 l 2^l 2l为模
B B Bring Z 2 \mathbb Z_2 Z2,即以2为模
λ \lambda λ计算安全系数
⊕ \oplus 异或门
ζ l , ζ l , m ( m > l ) \zeta_l, \zeta_{l,m} (m>l) ζl,ζl,m(m>l)无损lifting操作,映射 Z L → Z \mathbb Z_L\rightarrow \mathbb Z ZLZ,映射 Z L → Z M \mathbb Z_L\rightarrow \mathbb Z_M ZLZM
L , M , N L,M,N L,M,N 2 l , 2 m , 2 n 2^l, 2^m, 2^n 2l,2m,2n
[ k ] [k] [k] 0 , 1 , . . , k − 1 {0, 1, .., k-1} 0,1,..,k1
1 { b } 1\{b\} 1{b} b = t r u e b=true b=true时为1,反之为0
i n t ( x ) int(x) int(x) u i n t ( x ) uint(x) uint(x)对于 x ∈ Z l x\in \mathbb Z^l xZl,分别代表有符号和无符号值,int(x)=uint(x)−MSB(x)L
MSB(x)MSB(x) = 1 { x ≥ 2 l − 1 } =1\{x\geq 2^{l-1}\} =1{x2l1},表示最有效高位
F M i l l l ( x , y ) F_{Mill}^l(x, y) FMilll(x,y) F M i l l l ( x , y ) = ⟨ z ⟩ B = 1 { x < y } F_{Mill}^l(x, y)=\langle z\rangle^B=1\{x<y\} FMilll(x,y)=zB=1{x<y}在这里插入图片描述
F w r a p l F_{wrap}^l Fwrapl F w r a p l = F M i l l l ( L − 1 − x , y ) : w = w r a p ( x , y , L ) = 1 { x + y ≥ L } F_{wrap}^l=F_{Mill}^l(L-1-x, y): w=wrap(x, y, L)=1\{x+y\geq L\} Fwrapl=FMilll(L1x,y):w=wrap(x,y,L)=1{x+yL}
e e e e = 1 { ( x + y m o d    L ) = L − 1 } e=1\{(x+y \mod L)=L-1\} e=1{(x+ymodL)=L1},判断是否全是1
F w r a p & a l l 1 s l F_{wrap\&all1s}^l Fwrap&all1sl F w r a p & a l l 1 s l ( x , y ) = ( ⟨ w ⟩ B ∣ ∣ ⟨ e ⟩ B ) F_{wrap\&all1s}^l(x,y)=(\langle w\rangle^B||\langle e\rangle^B) Fwrap&all1sl(x,y)=(wBeB),至多一项是1
∗ m *_m m x ∗ m y = x y m o d    M x*_m y=xy\mod M xmy=xymodM,从 Z × Z → Z M \mathbb Z \times \mathbb Z \rightarrow \mathbb Z_M Z×ZZM
l l lbitwidth
s s sscale
l − s l-s ls整数部分的bitwidth
F i x ( x , l , s ) Fix(x, l, s) Fix(x,l,s) F i x ( x , l , s ) = x 2 s m o d    L Fix(x, l, s)=x2^s \mod L Fix(x,l,s)=x2smodL,从实数转到定点数表示
u r t ( l , s ) ( a ) urt_{(l,s)}(a) urt(l,s)(a)对于无符号数, u r t ( l , s ) ( a ) = u i n t ( a ) / 2 s urt_{(l,s)}(a)=uint(a)/2^s urt(l,s)(a)=uint(a)/2s,从定点数转到实数表示
s r t ( l , s ) ( a ) srt_{(l,s)}(a) srt(l,s)(a)对于有符号数, s r t ( l , s ) ( a ) = i n t ( a ) / 2 s srt_{(l,s)}(a)=int(a)/2^s srt(l,s)(a)=int(a)/2s,从定点数转到实数表示
> > L , > > A >>_L, >>_A >>L,>>A逻辑右移和算术右移

4.4 密码学基础

  • 秘密共享(SS)
    2-out-of-2加性秘密共享: x = ⟨ x ⟩ 0 l + ⟨ x ⟩ 1 l m o d    L x=\langle x\rangle_0^l+\langle x\rangle_1^l \mod L x=x0l+x1lmodL
  • 不经意传输(OT)
    1-out-of-k OT,用OT Extension (OTE)实现,并用了Correlated OT (COT)。

4.5 2PC基本函数

  • 百万富翁/wrap
    F M i l l l = 1 { x < y } F_{Mill}^l=1\{x<y\} FMilll=1{x<y},CrypTFlow2中通信量低于 λ l + 14 l \lambda l+14l λl+14l bits和 log ⁡ l \log l logl rounds。
    F w r a p l = F M i l l l ( L − 1 − x , y ) : w = w r a p ( x , y , L ) = 1 { x + y ≥ L } F_{wrap}^l=F_{Mill}^l(L-1-x, y): w=wrap(x, y, L)=1\{x+y\geq L\} Fwrapl=FMilll(L1x,y):w=wrap(x,y,L)=1{x+yL}
  • AND
    输入 ⟨ x ⟩ B , ⟨ y ⟩ B \langle x\rangle^B, \langle y\rangle^B xB,yB,输出 ⟨ x ∧ y ⟩ B \langle x \land y\rangle^B xyB,用Beaver bit-triples实现,CrypTFlow2中通信量为 λ + 20 \lambda+20 λ+20
  • Boolean to Arithmetic (B2A)
    输入boolean share,输出相同值的算术share,采用COT协议实现,通信量为 λ + l \lambda+l λ+l bits。
  • Multiplexer (MUX)
    ⟨ x ⟩ B \langle x\rangle^B xB ⟨ y ⟩ l \langle y\rangle^l yl作为输入,输出 ⟨ z ⟩ l \langle z\rangle^l zl,如果 x = 1 x=1 x=1,则 z = y z=y z=y,反之同理。本文提出的协议将通信量从 2 ( λ + 2 l ) 2(\lambda+2l) 2(λ+2l)(CrypTFlow2)降到 2 ( λ + l ) 2(\lambda+l) 2(λ+l)
  • Lookup Table (LUT)
    对于表 T T T M M M个入口,每个 n n n-bits,输入 ⟨ x ⟩ m \langle x\rangle^m xm ⟨ z ⟩ n \langle z\rangle^n zn,满足 z = T [ x ] z=T[x] z=T[x]。可以用1-out-of-m OT实现,通信量为 2 λ + M n 2\lambda +Mn 2λ+Mn bits。这是个查表的操作,输入和输出的位数是不同的。

5 构建块协议

5.1 零扩展和有符号扩展

对于 m m m-bit的数 x ∈ Z M x\in \mathbb Z_M xZM,将其转换为 n n n-bit的数( n > m n>m n>m),这个过程就称为扩展(extension)。零扩展和有符号扩展分别用于扩展无符号数和有符号数的位宽。
零扩展(Zero Extension)
P 0 P_0 P0 P 1 P_1 P1两方输入 ⟨ x ⟩ m \langle x\rangle^m xm,扩展输出 ⟨ y ⟩ n \langle y\rangle^n yn,要求满足 u n i t ( x ) = u i n t ( y ) unit(x)=uint(y) unit(x)=uint(y)。对于 x m ∈ Z M x^m\in \mathbb Z_M xmZM,可以得到 【问:这个等式在后面广泛使用,没太理解怎么来的】【答:其实 − w M -wM wM就是实现的 m o d    M \mod M modM计算过程,防止求和在 Z 2 m \mathbb Z_{2^m} Z2m环上溢出】
x m = ⟨ x ⟩ 0 m + ⟨ x ⟩ 1 m − w M x^m = \langle x \rangle_0^m+\langle x \rangle_1^m-wM xm=x0m+x1mwM
其中, w = w r a p ( ⟨ x ⟩ 0 m , ⟨ x ⟩ 1 m , M ) w=wrap(\langle x \rangle_0^m, \langle x \rangle_1^m, M) w=wrap(x0m,x1m,M),这是个boolean share,需要转换为算术share。这里考虑在 n − m n-m nm环上转换,原因就是下面的模约减步骤会使通信量大大降低。
F B 2 A n − m ( ⟨ w ⟩ B ) = ⟨ w ⟩ n − m ∈ Z 2 n − m F_{B2A}^{n-m}(\langle w\rangle^B)=\langle w\rangle^{n-m}\in \mathbb Z_{2^{n-m}} FB2Anm(wB)=wnmZ2nm

w = ⟨ w ⟩ 0 n − m + ⟨ w ⟩ 1 n − m − w r a p ( ⟨ w ⟩ 0 n − m , ⟨ w ⟩ 1 n − m , Z 2 n − m ) 2 n − m w = \langle w\rangle_0^{n-m} + \langle w\rangle_1^{n-m}-wrap(\langle w\rangle_0^{n-m}, \langle w\rangle_1^{n-m}, \mathbb Z_{2^{n-m}})2^{n-m} w=w0nm+w1nmwrap(w0nm,w1nm,Z2nm)2nm

M ∗ n w = M ∗ n ( ⟨ w ⟩ 0 n − m + ⟨ w ⟩ 1 n − m − w r a p ( ⟨ w ⟩ 0 n − m , ⟨ w ⟩ 1 n − m , Z 2 n − m ) 2 n − m ) M_{*n}w = M_{*n}(\langle w\rangle_0^{n-m} + \langle w\rangle_1^{n-m} - wrap(\langle w\rangle_0^{n-m}, \langle w\rangle_1^{n-m}, \mathbb Z_{2^{n-m}})2^{n-m}) Mnw=Mn(w0nm+w1nmwrap(w0nm,w1nm,Z2nm)2nm)

其中, M ∗ n w r a p ( ⋅ ) 2 n − m = M w r a p ( ⋅ ) 2 n − m m o d    N = w r a p ( ⋅ ) 2 n m o d    N = 0 M_{*n}wrap(\cdot)2^{n-m}=Mwrap(\cdot)2^{n-m} \mod N=wrap(\cdot)2^{n} \mod N=0 Mnwrap()2nm=Mwrap()2nmmodN=wrap()2nmodN=0(这一步称作“模约减”,modulo-reduce),所以上式子转换为:
M ∗ n w = M ∗ n ( ⟨ w ⟩ 0 n − m + ⟨ w ⟩ 1 n − m ) M_{*n}w = M_{*n}(\langle w\rangle_0^{n-m} + \langle w\rangle_1^{n-m}) Mnw=Mn(w0nm+w1nm
于是:
y = ∑ b = 0 1 ( ⟨ x ⟩ b m − M ⟨ w ⟩ b n − m ) m o d    N y = \sum_{b=0}^1(\langle x\rangle_b^m-M\langle w\rangle_b^{n-m}) \mod N y=b=01(xbmMwbnm)modN
这里是在 P 0 P_0 P0 P 1 P_1 P1上分别计算,然后求和取模,得到扩展后的结果。其中, x m o d    N = y x \mod N=y xmodN=y
算法如下:
在这里插入图片描述
需要 log ⁡ ( m + 2 ) \log(m+2) log(m+2) rounds和少于 λ ( m + 1 ) + 13 m + n \lambda(m+1)+13m+n λ(m+1)+13m+n bits的通信量。作为对比,用GC实现零扩展和有符号扩展需要 λ ( 4 m + 2 n − 4 ) \lambda(4m+2n-4) λ(4m+2n4) bits的通信量,大约是SIRNN的6倍。

有符号扩展(Signed Extension)
有符号扩展可以基于以下等式,通过转换无符号扩展得到,在环 Z \mathbb Z Z上:
i n t ( x ) = x ′ − 2 m − 1 , x ′ = x + 2 m − 1 m o d    M int(x)=x'-2^{m-1}, x'=x+2^{m-1} \mod M int(x)=x2m1,x=x+2m1modM
证明如下:
在这里插入图片描述
于是:
S E x t ( x , m , n ) = Z E x t ( x , m , n ) − 2 m − 1 SExt(x, m, n)=ZExt(x, m, n)-2^{m-1} SExt(x,m,n)=ZExt(x,m,n)2m1

相比零扩展,没有额外的通信开销。

5.2 截断

首先,规定 > > L , > > A >>_L,>>_A >>L,>>A分别表示逻辑右移和算术右移,它们的输入和输出都是在 Z L \mathbb Z_L ZL环上。
T R ( x , s ) TR(x, s) TR(x,s)表示截断且减小(truncate & reduce),将 x ∈ Z L x\in \mathbb Z_L xZL截断且减小 s s s-bits,最终得到的 x x x在更小的 Z 2 l − s \mathbb Z_{2^{l-s}} Z2ls环上。
逻辑右移
Toy example: x = 101001 x=101001 x=101001逻辑右移3位,则 x ′ = 000101 x'=000101 x=000101(右侧截掉,左侧补0)。
对于 x ∈ Z L x\in \mathbb Z_L xZL,则 x = ⟨ x ⟩ 0 l + ⟨ x ⟩ 1 l m o d    L x=\langle x\rangle_0^l+\langle x\rangle_1^l \mod L x=x0l+x1lmodL,记 ⟨ x ⟩ b l = u b ∣ ∣ v b \langle x\rangle_b^l=u_b||v_b xbl=ubvb u b u_b ub是高位, v b v_b vb是低位),其中 u b ∈ { 0 , 1 } l − s , v b ∈ { 0 , 1 } s u_b\in\{0, 1\}^{l-s}, v_b\in\{0, 1\}^{s} ub{0,1}ls,vb{0,1}s。如下图:
在这里插入图片描述
根据前面提到的公式:
x m = ⟨ x ⟩ 0 m + ⟨ x ⟩ 1 m − w M x^m = \langle x \rangle_0^m+\langle x \rangle_1^m-wM xm=x0m+x1mwM
可以得到:
x > > L s = u 0 + u 1 − 2 l − s w r a p ( ⟨ x ⟩ 0 l , ⟨ x ⟩ 1 l , L ) + w r a p ( v 0 , v 1 , 2 s ) x>>_Ls=u_0+u_1-2^{l-s} wrap (\langle x\rangle_0^l, \langle x\rangle_1^l, L) + wrap(v_0, v_1, 2^s) x>>Ls=u0+u12lswrap(x0l,x1l,L)+wrap(v0,v1,2s)
上式中, w r a p ( v 0 , v 1 , 2 s ) wrap(v_0, v_1, 2^s) wrap(v0,v1,2s)这一项是考虑了进位。我们知道,加性秘密共享时, v v v部分可能会存在1位进位的情况,所以 w r a p ( v 0 , v 1 , 2 s ) wrap(v_0, v_1, 2^s) wrap(v0,v1,2s)就是判断 v 0 + v 1 v_0+v_1 v0+v1是否大于 2 s 2^s 2s,如果是,则会进1,如果不是,则为0。
常规做法是计算两个 w r a p ( ⋅ ) wrap(\cdot) wrap()值即可,但是SIRNN提出了一种优化,避开直接计算位宽是 l l l的那一项。文章中的Lemma 1即是这个引理:
在这里插入图片描述
通信开销低于 λ ( l + 3 ) + 15 + s + 20 \lambda(l+3)+15+s+20 λ(l+3)+15+s+20,并需要 log ⁡ l + 3 \log l+3 logl+3 rounds。
原文证明如下:
在这里插入图片描述

算法如下:
在这里插入图片描述

算术右移
对于无符号数,直接采用逻辑右移,对于有符号数,则需要采用算术右移。从前面零扩展到有符号扩展可以知道: i n t ( x ) = x ′ − 2 l − 1 , x ′ = x + 2 l − 1 m o d    L int(x)=x'-2^{l-1}, x'=x+2^{l-1} \mod L int(x)=x2l1,x=x+2l1modL,于是:
x > > A s = x > > L s − 2 l − s − 1 x>>_As = x>>_Ls-2^{l-s-1} x>>As=x>>Ls2ls1

截断且减小
Toy example: x = 101001 x=101001 x=101001截断且减小3位,则 x ′ = 101 x'=101 x=101
因为 2 l − s ∗ l w m o d    2 l − s = 0 2^{l-s}{*_l} w \mod 2^{l-s}=0 2lslwmod2ls=0(模约减),所以:
⟨ T R ( x , s ) ⟩ l − s = u 0 + u 1 + w r a p ( v 0 , v 1 , 2 s ) \langle TR(x, s)\rangle^{l-s}=u_0+u_1+wrap(v_0, v_1, 2^s) TR(x,s)ls=u0+u1+wrap(v0,v1,2s)

除以power-of-2
z < 0 , z = ⌈ i n t ( x ) / 2 s ⌉ m o d    L ; z ≥ 0 , z = ⌊ i n t ( x ) / 2 s ⌋ m o d    L z<0, z=\lceil int(x)/2^s\rceil \mod L; z\geq0, z=\lfloor int(x)/2^s\rfloor \mod L z<0,z=int(x)/2smodL;z0,z=int(x)/2smodL
实际上 i n t ( x ) / 2 s m o d    L int(x)/2^s \mod L int(x)/2smodL就是做 > > A >>_A >>A,取整括号即是将值往0靠近。令 m x = 1 { x ≥ 2 l − 1 } m_x=1\{x\geq 2^{l-1}\} mx=1{x2l1}判断 x x x的正负性, c = 1 { x m o d    2 s = 0 } c=1\{x\mod 2^s=0\} c=1{xmod2s=0}
m x = 1 m_x=1 mx=1,则 z < 0 , ⌈ z ⌉ z<0, \lceil z\rceil z<0,z;反之, ⌊ z ⌋ \lfloor z\rfloor z。所以有:
D i v P o w 2 ( x , s ) = ( x > > A s ) + m x ∧ c DivPow2(x, s)=(x>>_As)+m_x\land c DivPow2(x,s)=(x>>As)+mxc

5.3 混合位宽乘法

以前做乘法通常是用Beaver Triplet三元组实现,SIRNN中不能用了,因为加法和乘法的数bitwidth不一致。
无符号乘法
输入 ⟨ x ⟩ m , ⟨ y ⟩ n \langle x\rangle^m, \langle y\rangle^n xm,yn,输出 ⟨ z ⟩ l , z = x ∗ l y , l = n + m \langle z\rangle^l, z=x*_l y, l=n+m zl,z=xly,l=n+m
对于 x , y x,y x,y,在 Z \mathbb Z Z上有:
u i n t ( x ) ⋅ u i n t ( y ) = ( x 0 + x 1 − 2 m w x ) ⋅ ( y 0 + y 1 − 2 n w y ) = x 0 y 0 + x 0 y 1 + x 1 y 0 + x 1 y 1 − 2 m w x y − 2 n w y x + 2 l w x w y uint(x)\cdot uint(y)=(x_0+x_1-2^mw_x)\cdot(y_0+y_1-2^nw_y)\\=x_0y_0+x_0y_1+x_1y_0+x_1y_1-2^mw_xy-2^nw_yx+2^lw_xw_y uint(x)uint(y)=(x0+x12mwx)(y0+y12nwy)=x0y0+x0y1+x1y0+x1y12mwxy2nwyx+2lwxwy
观察上式, x 0 y 0 , x 1 y 1 x_0y_0,x_1y_1 x0y0,x1y1都是可以本地计算的【本地计算为什么不管位宽是否一致?】 2 l w x w y 2^lw_xw_y 2lwxwy可以在 m o d    L \mod L modL时被消掉(模约减), w x y , x y x w_xy, x_yx wxy,xyx是boolean share和算术share的计算,本质上是MUX,可用直接用OT实现。最难的一项是交叉项 x 0 y 1 , x 1 y 0 x_0y_1, x_1y_0 x0y1,x1y0,SIRNN采用COT实现。
巧妙的一点在于:选择比特位短的一方作为receiver,比特位长的一方作为sender,这样在做OT的取数时,round数就会更少。
交叉项算法如下:

无符号乘法算法如下:
在这里插入图片描述
SIRNN利用1-out-of-2的COT来实现这个过程,将短的数按位拆解,每一位非0即1,然后做二选一的COT,每一位计算完成后,在本地累加起来。
通信开销大约是: λ ( 3 μ + v ) + μ ( μ + 2 v ) + 16 ( m + n ) \lambda(3\mu + v) + \mu(\mu + 2v) + 16(m + n) λ(3μ+v)+μ(μ+2v)+16(m+n),其中 μ = min ⁡ ( m , n ) , ν = max ⁡ ( m , n ) \mu = \min(m, n), ν = \max(m, n) μ=min(m,n),ν=max(m,n)。普通的扩展位数然后相乘的开销是: 3 λ ( μ + v ) + ( m + n ) 2 + 15 ( m + n ) 3\lambda(\mu+v)+(m+n)^2+15(m + n) 3λ(μ+v)+(m+n)2+15(m+n),大约是SIRNN的1.5x。

有符号乘法
布尔分享转换为算术分享:
⟨ x ⟩ A = ⟨ x ⟩ 0 B + ⟨ x ⟩ 1 B − 2 ⟨ x ⟩ 0 B ⟨ x ⟩ 1 B \langle x\rangle^A=\langle x\rangle_0^B+\langle x\rangle_1^B-2\langle x\rangle_0^B\langle x\rangle_1^B xA=x0B+x1B2x0Bx1B
基于前面无符号数和有符号数的关系,可以得到:无符号数 x ′ = x + 2 m − 1 m o d    M , y ′ = y + 2 n − 1 m o d    N x'=x+2^{m-1}\mod M, y'=y+2^{n-1}\mod N x=x+2m1modM,y=y+2n1modN。由秘密共享, x ′ = x 0 ′ + x 1 ′ m o d    M , y ′ = y 0 ′ + y 1 ′ m o d    N x'=x_0'+x_1' \mod M, y'=y_0'+y_1' \mod N x=x0+x1modM,y=y0+y1modN。有符号数 i n t ( x ) = x ′ − 2 m − 1 , i n t ( y ) = y ′ − 2 n − 1 int(x)=x'-2^{m-1}, int(y)=y'-2^{n-1} int(x)=x2m1,int(y)=y2n1。因此,在 Z \mathbb Z Z环上:
在这里插入图片描述
在这里插入图片描述

x ′ y ′ x'y' xy是无符号数的乘法,可以用algorithm 3计算, 2 m − 1 y b ′ , 2 n − 1 x b ′ 2^{m-1}y_b', 2^{n-1}x_b' 2m1yb,2n1xb也都可以在本地计算出来。难点是wrap项应该如何计算。
2 m + n − 1 w x ′ = 2 l − 1 w x ′ = 2 l − 1 ( ⟨ w x ′ ⟩ 0 B + ⟨ w x ′ ⟩ 1 B − 2 ⟨ w x ′ ⟩ 0 B ⟨ w x ′ ⟩ 1 B ) 2^{m+n-1}w_{x'}=2^{l-1}w_{x'}=2^{l-1}(\langle w_{x'}\rangle_0^B+\langle w_{x'}\rangle_1^B-2\langle w_{x'}\rangle_0^B\langle w_{x'}\rangle_1^B) 2m+n1wx=2l1wx=2l1(wx0B+wx1B2wx0Bwx1B)
其中, 2 ⟨ w x ′ ⟩ 0 B ⟨ w x ′ ⟩ 1 B 2\langle w_{x'}\rangle_0^B\langle w_{x'}\rangle_1^B 2wx0Bwx1B 2 l − 1 2^{l-1} 2l1相乘再 m o d    L \mod L modL后会被消除掉,所以无需计算。因此,上式变为:
2 m + n − 1 w x ′ = 2 l − 1 w x ′ = 2 l − 1 ( ⟨ w x ′ ⟩ 0 B + ⟨ w x ′ ⟩ 1 B ) 2^{m+n-1}w_{x'}=2^{l-1}w_{x'}=2^{l-1}(\langle w_{x'}\rangle_0^B+\langle w_{x'}\rangle_1^B) 2m+n1wx=2l1wx=2l1(wx0B+wx1B)
有符号的乘法相比无符号的乘法,也没有额外的开销。

矩阵乘法和卷积
矩阵乘法和卷积是很常见的(实际上可以展开为普通乘法做elment-wise乘和加),两个矩阵 A ∈ Z M d 1 × d 2 , A ∈ Z N d 2 × d 3 A\in \mathbb Z_M^{d1\times d2}, A\in \mathbb Z_N^{d2\times d3} AZMd1×d2,AZNd2×d3,输出矩阵乘法结果 A ∈ Z L d 1 × d 3 A\in \mathbb Z_L^{d1\times d3} AZLd1×d3,其中 l = m + n l=m+n l=m+n。做矩阵乘法需要 d 2 d_2 d2次乘以及 d 2 − 1 d_2-1 d21次加。
这个时候可能出现的问题是:加法导致溢出。一种解决方式是将element-wise乘后的结果扩展 e = ⌈ log ⁡ d 2 ⌉ e=\lceil \log d_2\rceil e=logd2-bits后,再做加法。但是,这样扩展开销很大,需要扩展 d 1 d 2 d 3 d_1d_2d_3 d1d2d3次。
于是本文这样做:考虑到前面算交叉项(CrossTerm)时,通信round数取决于较小的bitwidth,所以本文将bitwidth较大的一项拿去扩展 e e e-bits,在不增加开销的情况下,扩大了环。
通信开销大致为 λ ( 3 d 1 d 2 ( m + 2 ) + d 2 d 3 ( n + 2 ) ) + d 1 d 2 d 3 ( ( 2 m + 4 ) ( n + e ) + m 2 + 5 m ) \lambda(3d_1d_2(m+2)+d_2d_3(n+2))+d_1d_2d_3((2m+4)(n+e)+m^2+5m) λ(3d1d2(m+2)+d2d3(n+2))+d1d2d3((2m+4)(n+e)+m2+5m) bits。
算法如下:
在这里插入图片描述

乘且截断
首先调用有符号乘法,然后截断。输入 ⟨ x ⟩ m , ⟨ y ⟩ n \langle x\rangle^m, \langle y\rangle^n xm,yn,输出 ⟨ z ′ ⟩ l − s \langle z'\rangle^{l-s} zls z = i n t ( x ) ∗ l i n t ( y ) , z ′ = T R ( z , s ) z=int(x)*_l int(y), z'=TR(z, s) z=int(x)lint(y),z=TR(z,s)。其中 l = m + n l=m+n l=m+n

5.4 数值分解和MSNZB (Most Significant Non-Zero Bit)

数值分解
l l l-bit的数分解为 c c c个长度为 d = l / c d=l/c d=l/c的子串或数值,使得 x = z c − 1 ∣ ∣ . . . ∣ ∣ z 0 x=z_{c-1}||...||z_0 x=zc1...z0
算法如下:
在这里插入图片描述

MSNZB
返回最高非零比特的索引:比如 x = 001010 x=001010 x=001010返回的就是3。
算法如下:
在这里插入图片描述

5.5 MSB to Wrap Optimization

本文大量依赖于 w = w r a p ( ⟨ x ⟩ 0 l , ⟨ x ⟩ 1 l , L ) w=wrap(\langle x\rangle_0^l, \langle x\rangle_1^l, L) w=wrap(x0l,x1l,L),一些情况下,我们能得到 m x = M S B ( x ) m_x=MSB(x) mx=MSB(x) ⟨ m x ⟩ B \langle m_x\rangle^B mxB,于是 w = ( ( 1 ⊕ m x ) ∧ ( m 0 ⊕ m 1 ) ⊕ ( m 0 ∧ m 1 ) ) w=((1\oplus m_x)\land (m_0\oplus m_1)\oplus(m_0\land m_1)) w=((1mx)(m0m1)(m0m1)),其中 m b = M S B ( ⟨ x ⟩ B l ) m_b=MSB(\langle x\rangle_B^l) mb=MSB(xBl)。当 m x m_x mx是秘密分享时,使用 ( 4 1 ) \binom{4}{1} (14)-OT;当 m x m_x mx是明文时,使用 ( 2 1 ) \binom{2}{1} (12)-OT。

6 构建数学库

6.1 指数

r E x p ( z ) = e − z , z ∈ R + rExp(z)=e^{-z}, z\in \mathbb R^+ rExp(z)=ez,zR+的值,首先将输入 x x x分成 k k k段,然后每段在LUT (Look Up Table)进行查表,将得到的结果相乘。
算法如下:
在这里插入图片描述

6.2 Sigmoid和Tanh

s i g m o i d ( z ) = 1 1 + e − z sigmoid(z)=\frac{1}{1+e^{-z}} sigmoid(z)=1+ez1,可以表示如下:
在这里插入图片描述
其中, h ( z ) = 1 1 + r E x p ( z ) h(z)=\frac{1}{1+rExp(z)} h(z)=1+rExp(z)1的计算是先求 r E x p rExp rExp然后求倒数:
在这里插入图片描述
倒数则是采用Goldschmidt’s迭代近似算法实现,算法如下:
在这里插入图片描述
Tanh和sigmoid存在数学上的关系: T a n h ( z ) = e z − e − z e z + e − z = 2 s i g m o i d ( 2 z ) − 1 Tanh(z)=\frac{e^z-e^{-z}}{e^z+e^{-z}}=2sigmoid(2z)-1 Tanh(z)=ez+ezezez=2sigmoid(2z)1,所以可以用如上方式实现。

6.3 平方根倒数

计算 r s q r t ( x ) = 1 x rsqrt(x)=\frac{1}{\sqrt x} rsqrt(x)=x 1,为了防止分母为0,首先加上一个很小的 ϵ \epsilon ϵ r s q r t ( x ) = 1 x + ϵ rsqrt(x)=\frac{1}{\sqrt {x+\epsilon}} rsqrt(x)=x+ϵ 1
首先,进行初始化,然后用Goldschmidt法进行迭代,
算法如下:
在这里插入图片描述
在这里插入图片描述

通信开销和轮次汇总

在这里插入图片描述

参考资料:
SIRNN: A Math Library for Secure RNN Inference
2021-10-02-SIRNN

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Mr.zwX

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值