联邦学习,server权重文件的传输与模型更新

文章详细描述了联邦学习中的模型初始化、权重文件传输、模型聚合以及更新过程,强调了加密技术在保证安全性和数据隐私的重要性,以GitHub上的multi手写数字识别项目为例,展示了如何使用Keras构建模型并进行迭代优化。
摘要由CSDN通过智能技术生成

在联邦学习中,参与方之间需要传输权重文件和更新模型。下面是一种常见的方法来实现权重文件的传输和模型的更新:

1. 初始化模型:选择一个主机或协调者作为主服务器,并在该主机上初始化一个模型。这个模型将用作联邦学习的基础。

2. 当前权重文件的传输:参与方将其当前的权重文件传输给主服务器。可以使用加密技术来确保传输的安全性。

3. 模型聚合:主服务器接收到参与方的权重文件后,使用一种聚合算法(如FedAvg)将所有权重文件聚合起来,生成一个新的模型。聚合算法可以根据参与方的贡献进行加权聚合,以保持数据隐私和参与方的权益。

4. 更新传输:主服务器将新的模型传输回参与方,以便参与方可以使用更新后的模型进行后续的训练和推理。

5. 重复迭代:重复执行步骤2-4,直到达到预定的迭代次数或准则。每迭代一轮,参与方将更新模型后的权重文件传输回主服务器进行聚合,主服务器再将更新后的模型传输回参与方。

需要注意的是,在联邦学习中,传输和更新的频率可以根据具体的应用场景和网络带宽来决定。传输和更新的频率较低可以减少通信开销,但可能会影响模型的收敛速度。因此,需要根据具体情况权衡。另外,为了确保传输的安全性,可以使用加密和身份验证等技术来保护传输的数据。

本文仍以GitHub上multi手写数字识别的代码为例。

GitHub - ZeroWangZY/federated-learning: Everything about Federated Learning (papers, tutorials, etc.) -- 联邦学习

1.初始化模型

首先,打开fl_服务器端文件,文件中定义了三个大类,其中包含各种函数。

第一个:全局模型类

第二,GlobalModel_MNIST_CNN类

三,联邦服务类

文末有if语句,执行联邦学习通信类,参数有三:GlobalModel_MNIST_CNN类,地址和端口号。

由此,程序开始,虽然先调用通信类,但是通信类入口的第一行却是如下代码:

    def __init__(self, global_model, host, port):
        self.global_model = global_model()

要拿到全局模型还要去第一个参数指定的位置,所以跳转到GlobalModel_MNIST_CNN类

然而此类之中仍然有个调用,对象是:

GlobalModel
class GlobalModel_MNIST_CNN(GlobalModel):
    def __init__(self):
        super(GlobalModel_MNIST_CNN, self).__init__()

程序从底部攀升到最开始的第一大类。我们来看其中首先要执行的代码。

class GlobalModel(object):
    def __init__(self):
        self.model = self.build_model()
由于使用了super,self的指向,build_model指的是应该是GlobalModel_MNIST_CNN类的函数。我们急转直下回到第二个类里查看代码。
    def build_model(self):
        # ~5MB worth of parameters
        model = Sequential()
        model.add(Conv2D(32, kernel_size=(3, 3),
                         activation='relu',
                         input_shape=(28, 28, 1)))
        model.add(Conv2D(64, (3, 3), activation='relu'))
        model.add(MaxPooling2D(pool_size=(2, 2)))
        model.add(Dropout(0.25))
        model.add(Flatten())
        model.add(Dense(128, activation='relu'))
        model.add(Dropout(0.5))
        model.add(Dense(10, activation='softmax'))

        model.compile(loss=keras.losses.categorical_crossentropy,
                      optimizer=keras.optimizers.Adadelta(),
                      metrics=['accuracy'])
        return model

这部分通过keras创建了一个模型。并返回到全局模型类里的self.model变量。

而后暂时不用来回跑了,下一行代码将模型的权重和损失等信息创建保存下来。

self.current_weights = self.model.get_weights()
self.train_losses = []
self.valid_losses = []
self.train_accuracies = []
self.valid_accuracies = []

至此,初始化模型完毕,模型的联邦学习交互调优开始。

2.当前权重文件的传输

文件创建完毕后别忘了只是暂时被调用的。干完活还是要回到通信类里来。

上半部分的通信接口不多赘述,我们主要看training state部分

current_round表示当前轮数

json.dumps() 将字典转换为 JSON 格式的字符串,把模型和权重损失等数据拿下来。

在然后定义了一个更新函数:

def handle_client_update(data):

取出当前模型和客户端训练的模型评估并更新当前模型

# tolerate 30% unresponsive clients
                if len(self.current_round_client_updates) > FLServer.NUM_CLIENTS_CONTACTED_PER_ROUND * .7:
                    self.global_model.update_weights(
                        [x['weights'] for x in self.current_round_client_updates],
                        [x['train_size'] for x in self.current_round_client_updates],
                    )
                    aggr_train_loss, aggr_train_accuracy = self.global_model.aggregate_train_loss_accuracy(
                        [x['train_loss'] for x in self.current_round_client_updates],
                        [x['train_accuracy'] for x in self.current_round_client_updates],
                        [x['train_size'] for x in self.current_round_client_updates],
                        self.current_round
                    )




评估:

                    if self.global_model.prev_train_loss is not None and \
                            (self.global_model.prev_train_loss - aggr_train_loss) / self.global_model.prev_train_loss < .01:
                        # converges
                        print("converges! starting test phase..")
                        self.stop_and_eval()

将当前数据发送回客户端

@self.socketio.on('client_eval')
        def handle_client_eval(data):
            if self.eval_client_updates is None:
                return
            print("handle client_eval", request.sid)
            print("eval_resp", data)
            self.eval_client_updates += [data]

            # tolerate 30% unresponsive clients
            if len(self.eval_client_updates) > FLServer.NUM_CLIENTS_CONTACTED_PER_ROUND * .7:
                aggr_test_loss, aggr_test_accuracy = self.global_model.aggregate_loss_accuracy(
                    [x['test_loss'] for x in self.eval_client_updates],
                    [x['test_accuracy'] for x in self.eval_client_updates],
                    [x['test_size'] for x in self.eval_client_updates],
                );
                print("\naggr_test_loss", aggr_test_loss)
                print("aggr_test_accuracy", aggr_test_accuracy)
                print("== done ==")
                self.eval_client_updates = None  # special value, forbid evaling again

否则开启下一轮训练:

    def train_next_round(self):
        self.current_round += 1
        # buffers all client updates
        self.current_round_client_updates = []

        print("### Round ", self.current_round, "###")
        client_sids_selected = random.sample(list(self.ready_client_sids), FLServer.NUM_CLIENTS_CONTACTED_PER_ROUND)
        print("request updates from", client_sids_selected)

        # by default each client cnn is in its own "room"
        for rid in client_sids_selected:
            emit('request_update', {
                    'model_id': self.model_id,
                    'round_number': self.current_round,
                    'current_weights': obj_to_pickle_string(self.global_model.current_weights),

                    'weights_format': 'pickle',
                    'run_validation': self.current_round % FLServer.ROUNDS_BETWEEN_VALIDATIONS == 0,
                }, room=rid)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值