隐语实训09-SML入门基于SPU迁移机器学习算法实践

一、32位浮点数

32位浮点数(Single Precision Floating Point)是一种用于表示实数的标准格式,由IEEE 754标准定义。

表示方法

32位浮点数由三部分组成:

  1. 符号位(S):1位,表示数值的正负。
  2. 指数位(E):8位,用于表示数值的范围。
  3. 尾数位(M):23位,表示有效数字。

其表示公式为:

( − 1 ) S × 1. M × 2 ( E − 127 ) ( − 1 ) S × 1. M × 2 ( E − 127 ) ( − 1 ) S × 1. M × 2 ( E − 127 ) (−1)S×1.M×2(E−127)(-1)^S \times 1.M \times 2^{(E-127)}(−1)S×1.M×2(E−127) (1)S×1.M×2(E127)(1)S×1.M×2(E127)(1)S×1.M×2(E127)

  • 符号位 S 决定数的正负,0表示正数,1表示负数。
  • 指数位 E 采用偏移量为127的表示方法,即实际指数为 E−127。
  • 尾数位 M 代表小数部分,实际有效数字为 1.M。

优缺点

优点

  • 范围广:可以表示非常大的数和非常小的数。
  • 精度高:对大多数应用场景下的计算精度需求都能满足。

缺点

  • 计算复杂:浮点运算相对耗时,硬件实现复杂。
  • 存储空间大:占用32位存储空间。

应用场景

32位浮点数广泛用于科学计算、图形处理、机器学习等需要高精度和大范围数值表示的领域。

二、8位定点数

8位定点数(Fixed Point)是一种用于表示小范围数值的表示方法,适用于嵌入式系统和资源受限的环境。

表示方法

8位定点数的表示方法有多种,常见的是 Q7 格式,即:

  • 符号位(S):1位,表示数值的正负。
  • 整数位:0位。
  • 小数位:7位,表示小数部分。

其表示公式为:

( − 1 ) S × ( M 27 ) ( − 1 ) S × ( M 2 7 ) ( − 1 ) S × ( 27 M ) (−1)S×(M27)(-1)^S \times \left(\frac{M}{2^7}\right)(−1)S×(27M) (1)S×(M27)(1)S×(27M)(1)S×(27M)

  • 符号位 S 决定数的正负,0表示正数,1表示负数。
  • 尾数位 M 直接表示小数部分。

优缺点

优点

  • 计算简单:定点数运算简单,硬件实现高效。
  • 存储空间小:仅占用8位存储空间,节省内存。

缺点

  • 范围有限:只能表示较小范围的数值。
  • 精度有限:小数位越多,能表示的范围越小。

应用场景

8位定点数常用于嵌入式系统、DSP(数字信号处理)和物联网设备中,这些场景对计算资源和存储空间要求严格,且数值范围和精度需求较低。

三、比较与选择

精度与范围

  • 32位浮点数:适用于需要高精度和大范围数值的应用场景,如科学计算和机器学习。
  • 8位定点数:适用于资源受限且数值范围和精度要求较低的场景,如嵌入式系统和简单的信号处理。

计算复杂性

  • 32位浮点数:计算复杂,硬件实现成本高。
  • 8位定点数:计算简单,硬件实现成本低。

存储需求

  • 32位浮点数:占用32位存储空间。
  • 8位定点数:占用8位存储空间,更节省内存。

实际应用

  • 32位浮点数:广泛用于需要高精度和广泛范围的领域,如科学计算、图形处理、机器学习等。
  • 8位定点数:广泛用于嵌入式系统、物联网设备和DSP等资源受限的领域。

四、SML实践

接下来实践一个代码,来展示如何使用 SecretFlow 库和 SPU(Secure Processing Unit)设备来执行隐私保护的计算任务。代码涵盖了网络配置、数学运算的模拟、数据加载和处理、以及网络模拟操作等多个方面。模拟网络条件变化前后对计算任务性能的影响

import secretflow as sf
import spu
import os
import numpy as np
import jax.numpy as jnp
import jax
import jax.lax
import spu.utils.simulation as spsim
import spu.spu_pb2 as spu_pb2
from functools import partial

network_conf ={
    "parties":{
        "alice":{
            "address":"alice:8000",
        },
        "bob":{
            "address":"bob:8000",
        },
    },
}


party = os.getenv("SELF_PARTY","alice")
sf.shutdown()
sf.init(
    address="127.0.0.1:6379",
    cluster_config={**network_conf,"self_party": party},
    log_to_driver=True,
)

!yum install -y iproute-tc

#we know that dk is wrong when |x| is very small
# Let us try it.(we only show part here.)
# define some test function and data used in simulation
#def test_square_and_sum_when_x_small(x):
#    return jnp.sum(jnp.square(x))

def compute_dk_func(x, eps=1e-6, iterations=100):
    result = x
    for _ in range(iterations):
        result = jnp.square(result)
        result = jax.lax.rsqrt(jnp.sum(result) + eps)
    return result
    
x = np.array([1e-5]* 10)

#First,we run SPU with simulator
#Indeed,simulation can be run within single node.
# a.run with CHEETAH
sim_che = spsim.Simulator.simple(2, spu_pb2.ProtocolKind.CHEETAH, spu_pb2.FieldType. FM64)
spsim.sim_jax(sim_che, test_square_and_sum_when_x_small)(x)

#b.run with ABY3
# this time,we alse print some profile stats.
config_aby = spu.RuntimeConfig(
    protocol=spu_pb2.ProtocolKind.ABY3,
    field=spu.FieldType.FM64,
    fxp_fraction_bits=18,
    enable_hal_profile=True,
    enable_pphlo_profile=True,
)
sim_aby=spsim.Simulator(3,config_aby)
print(spsim.sim_jax(sim_aby, test_square_and_sum_when_x_small)(x))

!tc qdisc del dev eth0 root


!ping -c 4 bob

!ping -c 4 alice

# Emulation should be run from source in SPU, so we use Secretflow here to do the efficiency experiments.
# You can use the similar trick for emulation directly in SPU.
def compute_dk_func(x,eps=1e-6):
    return jax.lax.rsqrt(jnp.sum(jnp.square(x))+ eps)

x = np.random.rand(1_000_000)


# SPU settings
cluster_def={
    'nodes':[
        {'party':'alice','id':'local:0','address': 'alice'+ ':12945'},
        {'party':'bob','id':'local:1','address':'bob'+':12945'},
    ],
    'runtime_config':{
        #SEMI2K support 2/3 PC,ABY3 only support 3PC, CHEETAH only support 2PC.
        # pls pay attention to size of nodes above, nodes size need match to Pc setting.
        'protocol':spu.spu_pb2.SEMI2K,
        'field':spu.spu_pb2.FM64
    },
}
alice_device = sf.PYU("alice")
bob_device = sf.PYU("bob")
spu_device =sf.SPU(cluster_def)

#first, load data to PYU
alice_data = alice_device(lambda x: x)(x)

#SPU may need some init, so we run this twice...
ret  = spu_device(compute_dk_func)(alice_data)
sf.reveal(ret);

#调整网络状况,限制带宽和延迟
!tc qdisc add dev eth0 root handle 1: tbf rate 100mbit burst 128kb limit 10000
!tc qdisc add dev eth0 parent 1:1 handle 10: netem delay 10msec limit 8000

!ping -c 4 bob

!ping -c 4 alice
#调整网络状况后再执行计算任务,此次任务执行时间应该变长了
ret = spu_device(compute_dk_func)(alice_data)
sf.reveal(ret);

!tc qdisc del dev eth0 root


接下来作一个jax.numpy.digitize 的MPC-friendly的实现,原课程的案例如下:

x=np.array([0.0, 0.2, 6.4, 3.0, 1.6, 12.0])
bins =np.array([0.0, 1.0, 2.5, 4.0, 10.0])
jnp.digitize(x,bins)


config_aby = spu.RuntimeConfig(
    protocol=spu_pb2.ProtocolKind.ABY3,
    field=spu.FieldType.FM64,
    fxp_fraction_bits=18,
    enable_hal_profile=True,
    enable_pphlo_profile=True,
)
sim_aby=spsim.Simulator(3,config_aby)
print(spsim.sim_jax(sim_aby, jnp.digitize)(x,bins))

# MPC-friendly jnp.digitize example
# Note: here,we only deal the case of `right=False', other cases are similar.
def my_digitize(x, bins):
    # vectorize
    com=x.reshape(*x.shape, -1)>= bins
    #count the number ofxthat exceeds bins
    return jnp.sum(com, axis=1)
print(spsim.sim_jax(sim_aby,my_digitize)(x,bins))

在这里插入图片描述

在原案例中的实现在每次迭代中都对整个数组进行平方和开方操作,这可能导致不必要的计算负担。现在使用广播和累计求和直接计算看下效果

def optimized_digitize(x, bins):
    # 使用广播和累积求和直接计算
    return jnp.sum(x[:, None] >= bins, axis=-1)
print(spsim.sim_jax(sim_aby,optimized_digitize)(x,bins))

任务时间上快了一些,通信没什么变化,还是有一定效果的

在这里插入图片描述

  • 55
    点赞
  • 32
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值