简单介绍
CrypTen是Facebook在2019年10月开源的,用于多方安全计算(MPC)的框架。其底层依赖于深度学习框架PyTorch。
- 官网说明见: https://ai.facebook.com/blog/crypten-a-new-research-tool-for-secure-machine-learning-with-pytorch/.
- 源码地址:https://github.com/facebookresearch/CrypTen
使用场景
Alice提供模型,Bob提供数据, 加密的数据被分割并在2方参与计算,最终将计算结果合并解密得到结果。
使用示例
x_enc = crypten.cryptensor([1, 2, 3])
y_enc = crypten.cryptensor([4, 5, 6])
z_enc = x_enc + y_enc
隐藏明文
crypten支持两种分享类型:arithmetic(ArithmeticSharedTensor)和binary(BinarySharedTensor)。
通过cryptensor()方法构造的MPCTensor(继承了CrypTensor)中,默认使用arithmetic分享类型,主要属性_tensor
值存放的是ArithmeticSharedTensor。
ArithmeticSharedTensor主要属性是share
(等价于属性_tensor
)。
构造过程主要为如下几步:
- 使用FixedPointEncoder将tensor编码,结果为scaled Integer类型的tensor。
- PRZS(psuedo-random sharing of zero)方式生成tensor。利用pytorch的isend和irecv,使得任意1方拥有n方中2方(prev,next)的随机数生成器的种子。核心代码如下
next_seed = torch.tensor(numpy.random.randint(-2 ** 63, 2 ** 63 - 1, (1,)))
prev_seed = torch.LongTensor([0]) # placeholder
# Send random seed to next party, receive random seed from prev party
if world_size >= 2: # Otherwise sending seeds will segfault.
next_rank = (rank + 1) % world_size
prev_rank = (next_rank - 2) % world_size
req0 = comm.get().isend(tensor=next_seed, dst=next_rank)
req1 = comm.get().irecv(tensor=prev_seed, src=prev_rank)
req0.wait()
req1.wait()
else:
prev_seed = next_seed
# Seed Generators
comm.get().g0.manual_seed(next_seed.item())
comm.get().g1.manual_seed(prev_seed.item())
3方(P1,P2,P3)计算情况,g0和g1的种子(Seed)对应如下结果:
\ | P1 | P2 | P3 |
---|---|---|---|
g0 | S1 | S2 | S3 |
g1 | S3 | S1 | S2 |
def PRZS(*size):
"""
Generate a Pseudo-random Sharing of Zero (using arithmetic shares)
This function does so by generating `n` numbers across `n` parties with
each number being held by exactly 2 parties. One of these parties adds
this number while the other subtracts this number.
"""
tensor = ArithmeticSharedTensor(src=SENTINEL)
current_share = generate_random_ring_element(*size, generator=comm.get().g0)
next_share = generate_random_ring_element(*size, generator=comm.get().g1)
tensor.share = current_share - next_share
return tensor
- 主节点(rank = 0) 的tensor = tensor + PRZS.share. 其他节点tensor = PRZS.share.
获取明文
获取方式是x_enc.get_plain_text()
方法,该方法主要步骤如下:
- 由于n方的PRZS.share的所有值相加为0. 通过pytorch的reduce方式,sum所有节点的tensor.share即可恢复原始tensor。
- 使用FixedPointEncoder将tensor解码
算数运算
“add”,“sub”,“mul”,“matmul"等几个方法在ArithmeticSharedTensor中实现。
其中"mul”,"matmul"等方法使用了Beaver Triples(Multiplication Triples)协议。协议具体过程如下:
- Obtain uniformly random sharings [a],[b] and [c] = [a * b]
- Additively hide [x] and [y] with appropriately sized [a] and [b]
- Open ([epsilon] = [x] - [a]) and ([delta] = [y] - [b])
- Return [z] = [c] + (epsilon * [b]) + ([a] * delta) + (epsilon * delta)
推理过程如下:
x * y =(x - a + a)(y - b + b)
= (epsilon + a)(delta + b)
= c + (epsilon * b) + (a * delta) + (epsilon * delta)
[x * y] = [c] + (epsilon * [b]) + ([a] * delta) + (epsilon * delta)
第三步在crypten具体实现中,是通过pytorch的reduce将epsilon和delta做了sum(所有节点值相等)。
a,b,c的生成方式
a,b,c的生存方式可以通过环境变量“CRYPTEN_PROVIDER_NAME”来设置。crypten支持3中方式:TFP(TrustedFirstParty,默认),TTP(TrustedThirdParty),HE(HomomorphicProvider)。
但目前HomomorphicProvider代码并未完成,TrustedThirdParty测试没有通过。只有TFP是可以的。
通信
CrypTen遵循标准的MPI编程模型:它为每一方运行一个单独的进程,但是每个进程运行一个相同的(完整)程序。 每个进程都有一个rank来表示自己。
crypten 分布式计算依赖的pytorch。crypten默认的配置DISTRIBUTED_BACKEND=gloo。
模型保护
模型的构建过程:
- 读取pytorch模型文件
- 将pytorch模型转换为onnx(Open Neural Network Exchange)文件
- 解读onnx文件,并构造crypten_model(Graph类或者Module类。Graph继承Container,Container继承Module)
使用场景:Alice提供模型,Bob提供数据
import crypten.mpc as mpc
import crypten.communicator as comm
ALICE = 0
BOB = 1
labels = torch.load('/tmp/bob_test_labels.pth').long()
count = 100 # For illustration purposes, we'll use only 100 samples for classification
@mpc.run_multiprocess(world_size=2)
def encrypt_model_and_data():
# Load pre-trained model to Alice
model = crypten.load('models/tutorial4_alice_model.pth', dummy_model=dummy_model, src=ALICE)
# Encrypt model from Alice
dummy_input = torch.empty((1, 784))
private_model = crypten.nn.from_pytorch(model, dummy_input)
private_model.encrypt(src=ALICE)
# Load data to Bob
data_enc = crypten.load('/tmp/bob_test.pth', src=BOB)
data_enc2 = data_enc[:count]
data_flatten = data_enc2.flatten(start_dim=1)
# Classify the encrypted data
private_model.eval()
output_enc = private_model(data_flatten)
# Compute the accuracy
output = output_enc.get_plain_text()
accuracy = compute_accuracy(output, labels[:count])
print("\tAccuracy: {0:.4f}".format(accuracy.item()))
encrypt_model_and_data()