本周继续隐语第10课《PPML入门/基于SPU的机器学习建模实践》的学习,本次课程是一次实践课,由蚂蚁集团隐私计算部的吴豪奇老师做的分享。第10课依然是属于SPU系列的进阶课程,因此课程中关于SPU的介绍,可以参考之前我的课程笔记《SML入门/基于SPU实现明文算法迁移密文模型的实践》以及《隐语课程学习笔记8-理解密态引擎SPU框架》。
目录
4.1 作业1: 基于FLAX库实现简单的MLP模型并密态执行对比
4.2 作业2: 基于transformer库,调用gpt2,完成明文与密文的推理结果对比
1. 课程背景信息
随着数据隐私问题日益受到关注,国家对于数据安全重视程度提升,以及监管力度不断加强,机器学习领域同样面临数据安全的问题。数据是机器学习算法预测准确性等效果的基础保证,训练高质量模型需要大量的有效数据,同时训练好的模型对外提供服务,需要用户数据作为输入。这些数据可能包含生物信息、金融信息等敏感数据,需要通过相应的技术手段,达到数据流通的技术信任。隐语给出的解决方案之一是采用安全多方计算(MPC),目前在业内用的最广的是基于秘密分享机制的MPC计算模式。基于MPC来实现PPML,是一种可行的技术解。
关于SPU的介绍,核心系统组件主要分为三块:前端-机器学习程序;编译器-生成并优化SPU的IR;运行时-以MPC协议的方式执行。
2. 基于MPC的NN密态训练及推理
以逻辑回归为例,说明四个主要问题:
(1)数据从哪来:数据提供发Alice、Bob分别在本地加载数据,这里课程PPT中的代码和图片感觉没完全对应上,注意一下就行,代码中的P1方是加载x1、y, P2加载x2。而图中是P1加载x1,P2加载x2和y。
(2)如何加密保护数据:数据方对数据进行加密(数据碎片化),发送给MPC计算方(秘密分享),这里提到采用外包模式(一般数据保护要求高的场景不建议使用,因为很容易发生合谋被窃取数据),好处是可以利用计算方的算力资源。计算方拿到密文(数据碎片)。
(3)如何定义模型计算:使用JAX实现前向和后向传播,如果这里不清楚的话,后面的4层MLP模型作业可以加深认识。
(4)如何定义模型密态计算:计算方以密文数据作为输入,将模型的训练/推理计算图通过SPU编译器转换为相应密态算子计算图,由SPU device按照MPC协议逐个执行。
(5)配置说明
3. 复杂模型的密态训练和推理
对于DNN、CNN, 甚至transformer等复杂模型,需要涉及到stax/flax的使用。
(1)stax
是 JAX 的一个简洁的神经网络库,它提供了一种简单的方式来定义和训练神经网络模型。stax
以功能性编程风格为基础,提供了一组基本的层构建块,可以用来组合和构建复杂的神经网络。stax
提供了一个轻量级的 API,用于定义和训练神经网络。使用函数式编程范式,这使得模型定义和训练非常简洁和易于理解。完全基于 JAX 构建,因此可以利用 JAX 的自动微分和硬件加速功能。
(2)flax
是一个功能强大且灵活的神经网络库,它建立在 JAX 之上。flax
提供了更加丰富的功能和更高的灵活性,可以用来构建、训练和部署复杂的深度学习模型。flax
提供了丰富的模块和功能,可以用来构建复杂的神经网络架构。允许用户定义自定义层和模块,并且支持动态图和静态图的混合使用。同样利用 JAX 的自动微分和硬件加速功能,同时提供更高级的抽象和功能。
以gpt2为例,说明如何复用构建好的模型,来实现密态的计算。形式比较简单,修改的部分也不多。主要就是:(1)数据加载,需要指定某个device加载某类数据;(2)将对应的处理函数声明在SPU设备中执行,将加载的数据作为输入参数传入。这里应该会涉及到碎片化等操作,传入到对应函数计算的数据应该已经是密态了。
这里给出了密文计算与明文计算的对应关系,明文算子有其对应的密文算子,SPU应该逐步迭代完善中。
4. 课后作业
说明:建议使用1.3.0-amd64版本的secretnote,用1.5.0可能会出现一些算子版本问题引起的报错。
4.1 作业1: 基于FLAX库实现简单的MLP模型并密态执行对比
(1)定义breast_cancer数据集处理函数
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import Normalizer
def breast_cancer(party_id=None,train:bool = True) ->(np.ndarray, np.ndarray):
x, y = load_breast_cancer(return_X_y=True)
x = (x-np.min(x)) / (np.max(x) - np.min(x))
x_train,x_test,y_train,y_test = train_test_split(x,y,test_size=0.2,random_state=42)
if train:
if party_id:
if party_id== 1:
return x_train[:,:15],_
else:
return x_train[:,15:],y_train
else:
return x_train,y_train
else:
return x_test,y_test
(2)基于flax定义四层MLP模型,激活函数用relu
from typing import Sequence
import flax.linen as nn
FEATURES = [30, 15, 8, 1]
class MLP(nn.Module):
features: Sequence[int]
@nn.compact
def __call__(self, x):
for feat in self.features[:-1]:
x = nn.relu(nn.Dense(feat)(x))
x = nn.Dense(self.features[-1])(x)
return x
(3)定义模型训练函数,使用MSE作为损失函数
import jax.numpy as jnp
def predict(params, x):
from typing import Sequence
import flax.linen as nn
FEATURES=[30,15,8,1]
class MLP(nn.Module):
features:Sequence[int]
@nn.compact
def __call__(self, x):
for feat in self.features[:-1]:
x = nn.relu(nn.Dense(feat)(x))
x = nn.Dense(self.features[-1])(x)
return x
return MLP(FEATURES).apply(params, x)
def loss_func(params, x, y):
pred = predict(params, x)
def mse(y, pred):
def squared_error(y, y_pred):
return jnp.multiply(y - y_pred, y - y_pred) / 2.0
return jnp.mean(squared_error(y, pred))
return mse(y, pred)
def train_auto_grad(x1, x2, y, params, n_batch=10, n_epochs=10, step_size=0.01):
x = jnp.concatenate((x1, x2),axis=1)
xs = jnp.array_split(x, len(x) / n_batch, axis=0)
ys = jnp.array_split(y, len(y) / n_batch, axis=0)
def body_fun(_, loop_carry):
params = loop_carry
for x, y in zip(xs, ys):
_, grads = jax.value_and_grad(loss_func)(params, x, y)
params = jax.tree_util.tree_map(
lambda p, g: p - step_size * g, params, grads
)
return params
params = jax.lax.fori_loop(0, n_epochs, body_fun, params)
return params
# Model init is purely public and run on SPU leads to significant accuracy loss, thus hoist out and run in python
def model_init(n_batch=10):
model = MLP(FEATURES)
return model.init(jax.random.PRNGKey(1), jnp.ones((n_batch, FEATURES[0])))
(4)采用AUC作为模型的性能指标
from sklearn.metrics import roc_auc_score
def validate_model(params,X_test,y_test):
y_pred=predict(params, X_test)
return roc_auc_score(y_test,y_pred)
(5)基于上述函数测试下明文训练MLP
import jax
# 加载数据
x1,_ =breast_cancer(party_id=1,train=True)
x2,y =breast_cancer(party_id=2,train=True)
# 超参数
n_batch=10
n_epochs = 10
step_size=0.01
# 训练模型
init_params=model_init(n_batch)
params=train_auto_grad(x1, x2, y, init_params, n_batch, n_epochs, step_size)
# 测试模型
X_test,y_test=breast_cancer(train=False)
auc= validate_model(params, X_test,y_test)
print(f'auc={auc}')
(6)基于SPU将上述明文训练转换为密文训练
import secretflow as sf
# Check the version of your SecretFlow
print('The version of SecretFlow:{}'.format(sf.__version__))
# In case you have a running scretflow runtime already.
sf.shutdown()
sf.init(['alice','bob'],address='local')
alice, bob = sf.PYU('alice'), sf.PYU('bob')
spu = sf.SPU(sf.utils.testing.cluster_def(['alice','bob']))
x1, _ = alice(breast_cancer)(party_id=1, train=True)
x2, y = bob(breast_cancer)(party_id=2, train=True)
init_params=model_init(n_batch)
device=spu
x1_,x2_,y_ = x1.to(device),x2.to(device),y.to(device)
init_params_ = sf.to(alice, init_params).to(device)
params_spu = spu(train_auto_grad, static_argnames=['n_batch','n_epochs','step_size'])(
x1_, x2_, y_, init_params_, n_batch=n_batch, n_epochs=n_epochs, step_size=step_size
)
(7)测试密文训练得到的模型效果
先将密文训练参数恢复成明文,然后利用恢复的明文参数对密文训练模型的预测结果进行评估(生产环境不建议这样使用,仅是为了验证本次实验的准确性)。可以看到密文的评估结果和明文模型基本一致,说明spu可以较好地在保持训练精度下的密文计算。
params=sf.reveal(params_spu)
X_test,y_test=breast_cancer(train=False)
auc=validate_model(params, X_test,y_test)
print(f'auc={auc}')
4.2 作业2: 基于transformer库,调用gpt2,完成明文与密文的推理结果对比
(1)安装transformers库
import sys
!{sys.executable} -m pip install transformers[flax]
(2)配置Huggingface镜像
import os
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
(3)加载预训练GPT-2模型
from transformers import AutoTokenizer, FlaxGPT2LMHeadModel, GPT2Config
tokenizer= AutoTokenizer.from_pretrained("gpt2")
pretrained_model = FlaxGPT2LMHeadModel.from_pretrained("gpt2")
(4)定义文本生成函数和预训练模型参数的使用,这里我设置了自定义配置,为了计算加速,将gpt2的模型层数的选择改成了可以自由配置,因为gpt2是12层的结构,所以为了加速验证效率,可以选择前k层做快速实验。当然最后还是选择12层进行最终的测试。
import jax.numpy as jnp
layers_selected = 12 # 选择的topk层数
def text_generation(input_ids,params):
# 自定义配置
config = GPT2Config(
n_layer=layers_selected, # 仅使用选择的topk层
vocab_size=pretrained_model.config.vocab_size, # 继承预训练模型的词汇表大小
n_positions=pretrained_model.config.n_positions, # 继承预训练模型的最大位置
n_ctx=pretrained_model.config.n_ctx, # 继承预训练模型的上下文大小
n_embd=pretrained_model.config.n_embd, # 继承预训练模型的嵌入维度
n_head=pretrained_model.config.n_head, # 继承预训练模型的头数
)
model = FlaxGPT2LMHeadModel(config=config)
for _ in range(10):
outputs = model(input_ids=input_ids, params=params)
next_token_logits=outputs[0][0,-1,:]
next_token=jnp.argmax(next_token_logits)
input_ids = jnp.concatenate([input_ids, jnp.array([[next_token]])],axis=1)
return input_ids
# 获取预训练模型的参数
params = pretrained_model.params
# 只保留topk层的参数
new_params = {
'transformer': {
'wte': params['transformer']['wte'],
'wpe': params['transformer']['wpe'],
'h': {str(i): params['transformer']['h'][str(i)] for i in range(layers_selected)}, # 仅使用选择的topk层
'ln_f': params['transformer']['ln_f'],
}
}
(5)(明文)在CPU上生成文本,可以正常输出下一个词的结果。
import jax.numpy as jnp
inputs_ids = tokenizer.encode('I enjoy walking with my cute dog',return_tensors='jax')
outputs_ids = text_generation(inputs_ids, new_params)
print('-' * 65 + '\nRun on CPU:\'n' + '-' * 65)
print(tokenizer.decode(outputs_ids[0], skip_special_tokens=True))
print('-'*65)
(6) (密文)在SPU上生成文本,这里我遇到了内存OOM的问题,什么都不调整,直接使用spu执行,可以完成step1和step2,但是output_token_ids = spu(text_generation)(input_token_ids_, model_params_) 执行的过程中,直接出现了OOM,我本地的服务器内存不大,只有16G。因此为了能验证通过,将预训练的参数和输入token ids都设置在alice侧加载,推理也在alice侧执行,这样可以正常完成。后续调整下资源后,重新run spu的过程。
import secretflow as sf
# In case you have a running secretflow runtime alLready
sf.shutdown()
sf.init(['alice','bob'],address='local')
alice,bob=sf.PYU('alice'), sf.PYU('bob')
conf = sf.utils.testing.cluster_def(['alice', 'bob'])
conf['runtime_config']['fxp_exp_mode'] = 1
conf['runtime_config']['experimental_disable_mmul_split'] = True
spu = sf.SPU(conf)
def get_model_params():
pretrained_model = FlaxGPT2LMHeadModel.from_pretrained("gpt2")
# 获取预训练模型的参数
params = pretrained_model.params
new_params = {
'transformer': {
'wte': params['transformer']['wte'],
'wpe': params['transformer']['wpe'],
'h': {str(i): params['transformer']['h'][str(i)] for i in range(layers_selected)}, # 仅使用指定层数
'ln_f': params['transformer']['ln_f'],
}}
return new_params
def get_token_ids():
return tokenizer.encode('I enjoy walking with my cute dog', return_tensors='jax')
model_params=alice(get_model_params)()
# input_token_ids=bob(get_token_ids)()
input_token_ids=alice(get_token_ids)()
print("step1")
#device=spu
#model_params_, input_token_ids_ = model_params.to(device), input_token_ids.to(device)
print("step2")
output_token_ids = alice(text_generation)(input_token_ids, model_params)
# output_token_ids = spu(text_generation)(input_token_ids_, model_params_)
(7) 检查SPU的输出,在SPU上运行GPT-2推理非常简单。可以执行reveal显示SPU生成的文本。
outputs_ids=sf.reveal(output_token_ids)
print('-' * 65 + '\nRun on CPU:\'n' + '-' * 65)
print(tokenizer.decode(outputs_ids[0], skip_special_tokens=True))
print('-'*65)
可以发现,SPU生成的文本与CPU生成的文本是完全一致的!
后续对内存进行了扩增后,终于可以跑通完整的spu-gpt2的推理,但是计算还挺耗时的,需要有点耐心。结果符合预期,密文模型与明文一致。