在联邦学习中,参与方之间需要传输权重文件和更新模型。下面是一种常见的方法来实现权重文件的传输和模型的更新:
1. 初始化模型:选择一个主机或协调者作为主服务器,并在该主机上初始化一个模型。这个模型将用作联邦学习的基础。
2. 当前权重文件的传输:参与方将其当前的权重文件传输给主服务器。可以使用加密技术来确保传输的安全性。
3. 模型聚合:主服务器接收到参与方的权重文件后,使用一种聚合算法(如FedAvg)将所有权重文件聚合起来,生成一个新的模型。聚合算法可以根据参与方的贡献进行加权聚合,以保持数据隐私和参与方的权益。
4. 更新传输:主服务器将新的模型传输回参与方,以便参与方可以使用更新后的模型进行后续的训练和推理。
5. 重复迭代:重复执行步骤2-4,直到达到预定的迭代次数或准则。每迭代一轮,参与方将更新模型后的权重文件传输回主服务器进行聚合,主服务器再将更新后的模型传输回参与方。
需要注意的是,在联邦学习中,传输和更新的频率可以根据具体的应用场景和网络带宽来决定。传输和更新的频率较低可以减少通信开销,但可能会影响模型的收敛速度。因此,需要根据具体情况权衡。另外,为了确保传输的安全性,可以使用加密和身份验证等技术来保护传输的数据。
本文仍以GitHub上multi手写数字识别的代码为例。
1.初始化模型
首先,打开fl_服务器端文件,文件中定义了三个大类,其中包含各种函数。
第一个:全局模型类
第二,GlobalModel_MNIST_CNN类
三,联邦服务类
文末有if语句,执行联邦学习通信类,参数有三:GlobalModel_MNIST_CNN类,地址和端口号。
由此,程序开始,虽然先调用通信类,但是通信类入口的第一行却是如下代码:
def __init__(self, global_model, host, port):
self.global_model = global_model()
要拿到全局模型还要去第一个参数指定的位置,所以跳转到GlobalModel_MNIST_CNN类
然而此类之中仍然有个调用,对象是:
GlobalModel
class GlobalModel_MNIST_CNN(GlobalModel):
def __init__(self):
super(GlobalModel_MNIST_CNN, self).__init__()
程序从底部攀升到最开始的第一大类。我们来看其中首先要执行的代码。
class GlobalModel(object):
def __init__(self):
self.model = self.build_model()
由于使用了super,self的指向,build_model指的是应该是GlobalModel_MNIST_CNN类的函数。我们急转直下回到第二个类里查看代码。
def build_model(self):
# ~5MB worth of parameters
model = Sequential()
model.add(Conv2D(32, kernel_size=(3, 3),
activation='relu',
input_shape=(28, 28, 1)))
model.add(Conv2D(64, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))
model.add(Flatten())
model.add(Dense(128, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(10, activation='softmax'))
model.compile(loss=keras.losses.categorical_crossentropy,
optimizer=keras.optimizers.Adadelta(),
metrics=['accuracy'])
return model
这部分通过keras创建了一个模型。并返回到全局模型类里的self.model变量。
而后暂时不用来回跑了,下一行代码将模型的权重和损失等信息创建保存下来。
self.current_weights = self.model.get_weights()
self.train_losses = []
self.valid_losses = []
self.train_accuracies = []
self.valid_accuracies = []
至此,初始化模型完毕,模型的联邦学习交互调优开始。
2.当前权重文件的传输
文件创建完毕后别忘了只是暂时被调用的。干完活还是要回到通信类里来。
上半部分的通信接口不多赘述,我们主要看training state部分
current_round表示当前轮数
json.dumps() 将字典转换为 JSON 格式的字符串,把模型和权重损失等数据拿下来。
在然后定义了一个更新函数:
def handle_client_update(data):
取出当前模型和客户端训练的模型评估并更新当前模型
# tolerate 30% unresponsive clients
if len(self.current_round_client_updates) > FLServer.NUM_CLIENTS_CONTACTED_PER_ROUND * .7:
self.global_model.update_weights(
[x['weights'] for x in self.current_round_client_updates],
[x['train_size'] for x in self.current_round_client_updates],
)
aggr_train_loss, aggr_train_accuracy = self.global_model.aggregate_train_loss_accuracy(
[x['train_loss'] for x in self.current_round_client_updates],
[x['train_accuracy'] for x in self.current_round_client_updates],
[x['train_size'] for x in self.current_round_client_updates],
self.current_round
)
评估:
if self.global_model.prev_train_loss is not None and \
(self.global_model.prev_train_loss - aggr_train_loss) / self.global_model.prev_train_loss < .01:
# converges
print("converges! starting test phase..")
self.stop_and_eval()
将当前数据发送回客户端
@self.socketio.on('client_eval')
def handle_client_eval(data):
if self.eval_client_updates is None:
return
print("handle client_eval", request.sid)
print("eval_resp", data)
self.eval_client_updates += [data]
# tolerate 30% unresponsive clients
if len(self.eval_client_updates) > FLServer.NUM_CLIENTS_CONTACTED_PER_ROUND * .7:
aggr_test_loss, aggr_test_accuracy = self.global_model.aggregate_loss_accuracy(
[x['test_loss'] for x in self.eval_client_updates],
[x['test_accuracy'] for x in self.eval_client_updates],
[x['test_size'] for x in self.eval_client_updates],
);
print("\naggr_test_loss", aggr_test_loss)
print("aggr_test_accuracy", aggr_test_accuracy)
print("== done ==")
self.eval_client_updates = None # special value, forbid evaling again
否则开启下一轮训练:
def train_next_round(self):
self.current_round += 1
# buffers all client updates
self.current_round_client_updates = []
print("### Round ", self.current_round, "###")
client_sids_selected = random.sample(list(self.ready_client_sids), FLServer.NUM_CLIENTS_CONTACTED_PER_ROUND)
print("request updates from", client_sids_selected)
# by default each client cnn is in its own "room"
for rid in client_sids_selected:
emit('request_update', {
'model_id': self.model_id,
'round_number': self.current_round,
'current_weights': obj_to_pickle_string(self.global_model.current_weights),
'weights_format': 'pickle',
'run_validation': self.current_round % FLServer.ROUNDS_BETWEEN_VALIDATIONS == 0,
}, room=rid)