A Python 3 example of federated learning in a realistic setting with PySyft

This tutorial uses the PySyft federated-learning framework for Python 3 to implement a basic multi-machine federated-learning setup under realistic conditions. The implementation uses a ServerWorker and ClientWorker rather than a VirtualWorker.

1. Background

Federated machine learning (also called federated learning, collaborative learning, or alliance learning) is a machine-learning framework that lets multiple organizations use data and build models jointly while satisfying requirements for user privacy, data security, and regulatory compliance [Baidu Baike]. Put simply, the participants jointly train a shared model without ever publishing their distributed data, as shown in Figure 1.
[Figure: federated learning framework diagram]

Figure 1. Federated learning overview
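The core idea in Figure 1 can be sketched in a few lines: each participant trains locally, and only model parameters, never raw data, are shared and combined. The snippet below is a minimal, dependency-free sketch of FedAvg-style parameter averaging (the aggregation step this tutorial deliberately omits); the weight lists are hypothetical.

```python
# Minimal sketch of federated aggregation: average one weight vector
# per participant, element-wise. Raw training data never leaves a party.
def federated_average(local_weights):
    """Element-wise average of one weight vector per participant."""
    n = len(local_weights)
    return [sum(ws) / n for ws in zip(*local_weights)]

# hypothetical weights reported by three participants
avg = federated_average([[1.0, 2.0], [3.0, 4.0], [2.0, 0.0]])
print(avg)  # [2.0, 2.0]
```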

2. Environment

This article uses the plain PySyft framework to build a basic federated-learning code skeleton in Python 3; it does not include encryption or third-party gradient averaging. Both Linux and Windows 10 work, and the key package versions are shown in Figure 2. Before reading on, please work through the basics of PySyft first: the PySyft tutorials.
[Figure: key package versions]

Figure 2. Key package versions
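The `WebsocketServerWorker`/`WebsocketClientWorker` API used below comes from the PySyft 0.2.x series; as an assumption about the exact setup (the figure with exact versions is not reproduced here), an environment along these lines should work, since the later 0.3+ releases of `syft` removed these modules:

```shell
# 0.2.x line of PySyft, which still ships syft.workers.websocket_*
pip install "syft<0.3"
```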

3. Python 3 code example

  • In the PySyft setup, the data owners (companies A and B) act as servers: each one deploys and starts a ServerWorker and uploads its labeled data to that worker. The code is as follows:
'''
--------------------------------------------------------
@File    :   server.py    
@Contact :   1183862787@qq.com or liuwang20144623@whu.edu.cn
@License :   (C)Copyright 2017-2018, CS, WHU
@Modify Time : 2020/6/16 20:42     
@Author      : Liu Wang    
@Version     : 1.0   
@Description : None
--------------------------------------------------------  
'''
import torch
import syft as sy
from syft.workers.websocket_server import WebsocketServerWorker
import sys

try:
    host = sys.argv[1]
    worker_id = sys.argv[2]  # renamed to avoid shadowing the built-in `id`
    port = int(sys.argv[3])  # the websocket server expects a numeric port
    print(host, worker_id, port)
except Exception as e:
    host, worker_id, port = None, None, None
    print(str(e))
    print('run the server by: "python server.py host id port"')
    print('for example: "python server.py localhost server1 8182"')
    sys.exit(1)

hook = sy.TorchHook(torch)
server_worker = WebsocketServerWorker(host=host,  # e.g. host="192.168.2.101" on a real network
                                      hook=hook, id=worker_id, port=port)
# hook = sy.TorchHook(torch, local_worker=server_worker)

# data in server
x = torch.tensor([[0,0],[0,1],[1,0],[1,1.]], requires_grad=True).tag("toy", "data")
y = torch.tensor([[0],[0],[1],[1.]], requires_grad=True).tag("toy", "target")
# x.private, x.private = True, True

x_ptr = x.send(server_worker)
y_ptr = y.send(server_worker)
print(x_ptr, y_ptr)

# x = torch.tensor([[0,0],[0,1],[1,0],[1,1.]], requires_grad=False)
# y = torch.tensor([[0],[0],[1],[1.]], requires_grad=False)
# server_worker.add_dataset(sy.BaseDataset(data=x, targets=y), key="vectors")

print('>>> server_worker:', server_worker)
print('>>> server_worker.list_objects():', server_worker.list_objects())
print('>>> server_worker.list_tensors():', server_worker.list_tensors())

server_worker.start()  # Might need to interrupt with `CTRL-C` or some other means

print('>>> server_worker.list_objects()', server_worker.list_objects())
print('>>> server_worker.objects_count()', server_worker.objects_count())
print('>>> server_worker.list_tensors():', server_worker.list_tensors())
print('>>> server_worker.host', server_worker.host)
print('>>> server_worker.port', server_worker.port)
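With the two data owners assumed by the client code below, the server script above is started once per owner. The hosts and ports here match the client configuration (`server1:8182`, `server2:8183`); on real machines, replace `localhost` with each server's address:

```shell
# on data owner A's machine (or a second terminal when testing locally)
python server.py localhost server1 8182
# on data owner B's machine
python server.py localhost server2 8183
```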
  • Start the servers first, then run the client. In the PySyft setup, the model owner (the data analyst) acts as the client: it deploys one ClientWorker per ServerWorker, queries each server's data by tag, and gets back pointers to the training data. The training procedure is similar to the example in the tutorial. The code is as follows:
'''
--------------------------------------------------------
@File    :   client.py    
@Contact :   1183862787@qq.com or liuwang20144623@whu.edu.cn
@License :   (C)Copyright 2017-2018, CS, WHU
@Modify Time : 2020/6/16 20:42     
@Author      : Liu Wang    
@Version     : 1.0   
@Description : None
--------------------------------------------------------  
'''
import torch
from torch import optim
import syft
# from syft.grid.public_grid import PublicGridNetwork
from syft.workers.websocket_client import WebsocketClientWorker
hook = syft.TorchHook(torch)

def train(model, datasets, ITER=20)->torch.nn.Module:
    """
    :param model: the torch model
    :param datasets: pointers to the datasets on the server workers,
            in the format [(data_ptr, target_ptr), (data_ptr, target_ptr), ...]
    :param ITER: the number of training iterations
    :return: the trained model
    """
    model_c = model.copy()
    # Training Logic
    for _ in range(ITER):  # avoid shadowing the built-in `iter`
        for data, target in datasets:
            # 1) send model to correct worker
            model_c = model_c.send(data.location)
            # 2) Call the optimizer for the worker using get_optim
            opt = optim.SGD(params=model_c.parameters(),lr=0.1)
            # 3) erase previous gradients (if they exist)
            opt.zero_grad()
            # 4) make a prediction
            pred = model_c(data)
            # 5) calculate how much we missed
            loss = ((pred - target)**2).sum()
            # 6) figure out which weights caused us to miss
            loss.backward()
            # 7) change those weights
            opt.step()
            # 8) get model (with gradients)
            model_c = model_c.get()
            # 9) print our progress
            print(data.location.id, loss.get())
    return model_c

if __name__ == '__main__':
    # create a client workers mapping to the server workers in remote machines
    remote_client_1 = WebsocketClientWorker(
        host='localhost',
        # host = '192.168.0.102', # the host of remote machine, the same as the Server host
        hook=hook,
        id='server1',
        port=8182)
    remote_client_2 = WebsocketClientWorker(
        host='localhost',
        # host = '192.168.0.102', # the host of remote machine, the same as the Server host
        hook=hook,
        id='server2',
        port=8183)
    remote_clients_list = [remote_client_1, remote_client_2]
    print('>>> remote_client_1', remote_client_1)
    print('>>> remote_client_2', remote_client_2)

    # get the data pointers which point to the real data in remote machines for training model
    datasets = []
    for remote_client in remote_clients_list:
        data = remote_client.search(["toy", "data"])[0]
        target = remote_client.search(["toy", "target"])[0]
        print('>>>data: ', data)
        print('>>>target: ', target)
        datasets.append((data, target))
    # exit(0)
    # define torch model
    model = torch.nn.Linear(2, 1)
    print('>>> untrained model: ', model.state_dict())
    # train model
    trained_model = train(model, datasets, ITER=10)
    print('>>> trained model: ', trained_model.state_dict())
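Stripped of the pointer machinery, each round of the loop above is ordinary SGD carried from one worker's local data to the next. The sketch below is a dependency-free illustration of that (pure Python, toy data mirroring server.py); the 0.5·err² loss is a simplification of the squared-error sum used above, and all names here are illustrative.

```python
def predict(w, b, x):
    # linear model with two inputs, mirroring torch.nn.Linear(2, 1)
    return w[0] * x[0] + w[1] * x[1] + b

def sgd_round(w, b, datasets, lr=0.1):
    """One federated round: visit each worker in turn and run SGD
    on its local data, carrying the model from worker to worker."""
    for data, targets in datasets:
        for x, y in zip(data, targets):
            err = predict(w, b, x) - y  # gradient of 0.5*(pred - y)**2 w.r.t. pred
            w = [wi - lr * err * xi for wi, xi in zip(w, x)]
            b -= lr * err
    return w, b

# toy data split across two "workers", mirroring server.py
worker1 = ([[0, 0], [0, 1]], [0, 0])
worker2 = ([[1, 0], [1, 1]], [1, 1])

w, b = [0.0, 0.0], 0.0
for _ in range(50):
    w, b = sgd_round(w, b, [worker1, worker2])
print(w, b)  # converges toward w ~ [1, 0], b ~ 0
```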

4. Summary

This tutorial used the PySyft federated-learning framework for Python 3 to implement basic multi-machine federated learning. The example has a serious flaw, however: once the client holds the data pointers, it can pull the raw data from the server simply by calling the pointers' .get() method. Preventing this requires the PyGrid package and permission settings on the data; an update will follow when time permits.
