FedMl-master之FedAvg算法的api学习笔记
弃坑pysyft之后,改学FedML了!
来自于这篇文章:FedML: A Research Library and Benchmark for Federated Machine Learning。
附上github的地址:https://github.com/FedML-AI/FedML
这里以FedAvg算法为例。
1、首先来回顾一下FedAvg算法。
Server有一个初始化的模型Wt,然后它发送给所有的client。client收到后,利用梯度下降算法更新得到一个Wt+1,发回给server。server收到所有client发送的Wt+1之后,做一个简单的加权平均,得到一个最新的Wt+1’。server继续发送Wt+1’给所有的client,以此类推不断更新。
2、它的代码入口在fedml_experiments/distributed/fedavg下,对应的 api 在 fedml_api/distributed/fedavg里。
打开 FedAvgAPI.py 可以看到有两个函数,一个 init_server,一个 init_client。
3、先来看看server。它调用了一个ServerManager的api,他是一个很方便收发消息的api。
打开FedAvgServerManager.py可以看到具体的函数。比如:
1)handle_message_receive_model_from_client:这个函数里的aggregator就是存储client发来的参数以及client的id。有了这个之后,aggregator每次都会check是不是所有的client都发来了参数,如果都发送了,就会启动aggregation。aggregation结束之后进行client的下一轮采样。
2)send_message_init_config:这个api是server最开始发送消息①的初始化。
3)send_message_sync_model_to_client:当要把合并完的Wt+1’发送给client时,就会调用这个api来发送信息。
这样server端的算法逻辑都完成了。
4、然后是Client。它调用的是FedAVGClientManager的api。打开FedAvgClientManager.py可以看到具体的函数。和server不一样的是,他有两个信息的handler,一个是MSG_TYPE_S2C_INIT_CONFIG,对于server发的第①个消息的接收,信息接收完之后,会调用一个train,得到最新的weight。;另一个是MSG_TYPE_S2C_SYNC_MODEL_TO_CLIENT,对于server发的合并之后的Wt+1’的接收,接收完之后进行更新、再train。
5、具体的trainer在FedAVGTainer.py,就是一个简单的pytorch的训练。aggregation在FedAVGAggregator.py,就是3.1)讲的。
后续依据自己的idea,在这基础之上改代码就可以!
环境安装请移步:https://blog.csdn.net/weixin_43952176/article/details/118655805