LEAF : A Benchmark for Federated Settings
LEAF 是联邦设置中的一个基准框架,适用于联邦学习、元学习、多任务学习和在设备上学习
一、数据集
首先介绍LEAF中涉及的数据集:
FEMNIST:用于图像分类
Shakespeare:下一个字符预测
Twitter:用于情感分析
Celeba:图像分类
合成数据:用于分类任务
Reddit:语言模型
二、安装
官方说的python环境3.5一下,我直接用的3.8的环境来跑,tensorflow-gpu 2.11.0
然后其他包直接pip install -r requirements.txt
三、使用示例
Twitter Sentiment Analysis
研究在使用 LEAF 框架下用 FedAvg 算法进行训练时,改变每个用户(用于训练)的最小样本数对模型准确性的影响。使用 Sentiment140 数据集(包含 160 万条推文),训练一个具有交叉熵损失的 2 层 LSTM 模型,并使用预训练的 GloVe 嵌入。
可以使用官方提供的脚本直接执行:leaf/paper_experiments $> ./sent140.sh
执行命令:sh sent140.sh
该脚本用于针对不同的最小样本计数执行实验,
先决条件
由于此实验需要预训练的词嵌入,所以建议运行 models/sent140/get_embs.sh 文件,该文件获取 300 维的预训练 GloVe 向量
执行命令:sh get_embs.sh
运行完之后,数据存储在 models/sent140/embs.json
![在这里插入图片描述](https://img-blog.csdnimg.cn/b8f1d5165fd547f984d69fe2bd811fed.png
数据集获取和预处理
LEAF 包含强大的脚本,用于获取数据并将其转换为 JSON 格式以便于使用。此外,这些脚本还能够从数据集中进行子采样,并将数据集拆分为训练集和测试集。
对于paper中的实验,作为第一步,我们将在 80-20 的训练/测试拆分中使用 50% 的数据集,并且我们将丢弃所有推文少于 10 条的用户。以下命令显示了如何完成此操作(本例中的 --spltseed 标志用于启用数据集的可重现生成)leaf/data/sent140/ $> ./preprocess.sh
sh preprocess.sh --sf 0.5 -t sample -s niid --tf 0.8 -k 3 --spltseed 1549775860
运行此脚本后,data/sent140/data 目录应包含 train/ 和 test/ 目录
执行模型
模型存放在 models/sent140/stacked_lstm.py,为了使用 FedAvg 训练这个模型,每轮有 2 个客户端,持续 10 轮,我们执行以下命令:leaf/models $>下
python main.py -dataset sent140 -model stacked_lstm -lr 0.0003 --clients-per-round 2 --num-rounds 10
或者,传递 -t small 代替后两个标志可提供相同的功能(如 models/baseline_constants.py 文件中所定义)。
全部的训练过程没有跑完,由于没有用GPU来进行训练,所以速度比较慢些
收集指标
执行上述命令会将系统和统计指标写入 leaf/models/metrics/stat_metrics.csv 和 leaf/models/metrics/sys_metrics.csv - 因为每次运行都会覆盖这些指标,所以强烈建议将生成的指标文件存储在不同的位置。
要试验不同的最小样本设置,使用不同的 -k 标志重新运行预处理脚本。可以使用存储库根目录中的 plots.py 文件生成下面显示的图。(不过这个代码没有找到)
结果与分析
执行此实验后,我们看到,虽然中位数性能仅随着数据不足的用户(即 k = 3)略有下降,但第 25 个百分位数(框底部)会急剧下降。
官方图:
TensorFlow 1.0和TensorFlow 2.0产生的冲突
① module ‘tensorflow‘ has no attribute ‘set_random_seed‘
解决:将tf.set_random_seed(1)替换为 tf.random.set_seed(1)
② from tensorflow.contrib import rnn 时 No module named ‘tensorflow.contrib’
解决:改为 from tensorflow.python.ops import rnn
③ module ‘tensorflow‘ has no attribute ‘logging‘
解决:将tf.logging替换成tf.compat.v1.logging
④ module ‘tensorflow’ has no attribute ‘reset_default_graph’
解决:tf.reset_default_graph() --------> tf.compat.v1.reset_default_graph()
⑤ module ‘tensorflow’ has no attribute ‘set_random_seed’
module ‘tensorflow’ has no attribute ‘placeholder’
该类解决方法全部参考④
⑥module ‘tensorflow.python.ops.rnn’ has no attribute ‘MultiRNNCell’
解决:改为 tf.compat.v1.nn.rnn_cell.MultiRNNCell
module ‘tensorflow.python.ops.rnn’ has no attribute ‘BasicLSTMCell’
解决:改为 tf.compat.v1.nn.rnn_cell.BasicLSTMCell
module ‘tensorflow._api.v2.nn’ has no attribute ‘softmax_cross_entropy_with_logits_v2’
解决:改为tf.compat.v1.nn.softmax_cross_entropy_with_logits_v2