一、PPML in SPU
1.PPML in SPU ML和MPC技术栈对比
(1).ML
• Forward/Backward
computations
• Tensors
Operations
• CNN/Transformers/GNN
SVM/K-means
• SGD/Adam/AMSGrad
(2).MPC
• Secret Sharing
Malicious security
• Addition/Multiplication
AND/XOR
• Secret Sharing
Yao's garbled circuits
• Honest/Dishonest
majority
• Mod
Prime/2的k次方
2.PPML in SPU SPU是什么
(1).前端:
• 编程语言为python
• 原生支持主流AI前端
• 对机器学习专家,额外学习成本低
• 复用AI框架前端能力,如自动求导
(2).编译器:
• 带隐私保护语义的IR
• 复用AI框架的部分编译优化
• MPC语义下的独占优化和翻译
(3).运行时:
• 多并发(指令,数据)
• 多协议支持(semi2k,cheetah,aby3)
• 部署模式透明(一次书写,到处执行)
二、浮点数和定点数
1.浮点数表示
2.定点数表示
3.浮点数和定点数
(1).浮点数
• 0点密集分布
• 取值范围较大
• 计算相对复杂(高次多项式,Lookup Table,…)
• 计算精度高,所有算子有严格的误差证明(<3 bit)
(2).定点数
• 均匀分布
• 取值范围较小
• 计算简单(整数运算,低次多项式)
• 计算精度较低,乘法依赖truncation,数学函数,通常使用近似解(不同算子误差范围不同)
三、明文算法迁移流程
1.将算法用jax的api去重新实现
Jax.numpy: numpy的替换
Jax.lax: low level数学算子
Jax.scipy : scipy的替换
2. 测试密态下的数值精度
• 模拟在定点数上运行的所有操作
• 可以提供真实的数值计算精度环
境,运行速度更快,快速实验
3. 测试密态下的实际性能
• 在真实的MPC协议上通过
多进程/Docker进行仿真
• 提供算法有效的性能结果
Step1.1: 使用JAX实现算法
Step1.2: 修改的insight
Step1.3: 明文下精度验证
Step2: SPU下验证精度(simulation)
Step2.1: 定义simulator
• 支持多种MPC协议
• 支持改变ring大小
• 支持改变fxp
• ……
•自定义config
Step2.2: 定义运行函数
Step2.3: 运行密态程序
• Cheetah协议运行
• ABY3协议运行
Step2.4: 分析和定位问题
• 思路1:从算法本身去看
可以尝试的idea:
[1].增大batch_size:减少梯度为0的概率
[2].增大eps:减少极端情况下学习率的抖动
[3].使用不带截断的Sigmoid近似(SR):减少梯度为0的概率
• 思路2:从MPC的角度去看
[1].Spu中能表示的最小正数为: 2!18 ≈ 3.8147e−06
[2].Cheetah协议计算乘法可能会产生0-2bit的误差
[3].猜想:若梯度很小(接近2!18),则平方运算造成的0-2bit误差可能会对结果造成显著的影响,甚至可能从最小正数溢出到最小负数。
• 改进思路仍然会以减少梯度为0入手:
[1].增大batch_size:减少梯度为0的概率
[2].增大eps:减少极端情况下学习率的抖动
[3].使用不带截断的Sigmoid近似(SR):减少梯度为0的概率
[4].增大fxp和环大小:提高表示精度,减少下溢到0点概率
Step3: SPU下验证性能(emulation)
Step3.1: 定义emulator
支持多种MPC协议
• 支持改变ring大小
• 支持改变fxp
• ……
Step3.2: 准备数据
• 与simulation不同,此时需要先将明文数据”密封” ,否则SPU会将其视作Public数据而无法密态下计算!
Step3.3: 运行程序
Step3.4: 根据cost profile优化性能
• e.g. 比较 div + norm 和 rsqrt + square
• ABY3协议下计算1000维向量的dk
四、常见问题
• Q1:怎么知道SPU支持哪些算子?
1. 自己动手: simulation, emulation
2. 参考文档:
https://www.secretflow.org.cn/docs/spu/latest/en-US/reference/np_op_status
https://www.secretflow.org.cn/docs/spu/latest/en-US/reference/xla_status
• Q2:怎么知道非线形算子大致的误差范围?
误差来自两个方面:
1. 系统设定误差(如环大小,fxp大小,truncation协议等)
2. 非线性算子拟合误差
因此很难给出如浮点数的误差估计,以下文档给出了一些数学算子的大致误差:
https://www.secretflow.org.cn/docs/spu/latest/en-US/development/fxp
• Q3:为什么明文下运行正常,密态下运行错误?
若运行报错:
1. 实现的算法是否jitable(即使用@jax.jit是否能运行)
2. 是否使用了SPU不支持的算子
若能运行,但误差极大,可以自查是否有以下情况:
1. 是否可能发生溢出:输入数据或参数是否太大或太小
2. SPU内部是否使用了浮点随机数生成器
3. 是否调用了线形代数算子(如矩阵分解,奇异值分解等)
若误差适中:可以考虑增大环的大小,提高fxp精度
• Q4:为什么Emulation的速度比simulation快很多?
数据没有seal(即load到PYU),SPU将其视为Public数据,所有计算在
明文下进行
• Q5:如何对密态算法进行优化?
有以下几个思路可以参考:
1. 减少耗时算子的调用(计算公式重写,多项式近似等)
2. 避免重复计算(空间换效率)
3. 并行化