目录
一、前言
1. 论文地址:
因为电脑性能有限,所以把四种聚类方式迭代次数都降低到迭代100次,而且只在MNist上,CIFAR10上跑起来也巨慢。官方给的代码跑不通,所以自己就改写了一下。
二、论文内容概要
因为目前只完整地跑完了MNist数据集上的实验,因此暂时先介绍到算法1。
1. 论文背景:
1.1 存在问题或现象:
1)抽样方法有偏倚
2)在服务器-客户端通信和训练的收敛稳定性方面不是最佳
1.2 论文提出的方法特点:
1) 聚类抽样选出的客户机具备更好的客户代表性
2)减少客户端在FL中随机聚集权重的差异(方差)
3)在客户端无需额外操作,可无缝集成至标准FL
4)与现有方法和技术兼容达到隐私增强
5)通过模型压缩减少通信量
2. 已有解决方案
1)FedAvg算法——随机选择m个客户端采样,对这m个客户端的梯度更新进行平均以形成全局更新同时用当前全局模型替换未采样的客户端
优点:相对于FedSGD在相同效果情况下,通讯成本大大降低
缺点:最终的模型是有偏倚的,不同于预期的每个客户端确定性聚合后的模型。
2)多项式分布抽样(MD抽样)算法——客户端抽样的概率对应于他们的相对样本量
优点:
(1)客户端抽样无偏性;
(2)通信量小(FedAvg和MD抽样是服务端-客户端通信最少的唯二方案)
缺点:
(1)仍然可能导致客户选择上有大的差异——选择客户端替换全局模型的次数差异
(2)这种差异导致了FL收敛性变化很大——在non-iid情况下,抽样的客户端都是基于自身数据分布改进全局模型,而未被抽样的客户端的全局模型则被直接替换
(3)损害了非抽样客户端的数据特异性
3. 论文方法
聚类采样方法:
1)Algorithm_1: sample size——基于样本大小的聚类采样聚合客户端的实现方法,该方法减少了客户端聚合权重的方差
2)Algorithm_2: models similarity——基于模型相似性,根据代表性梯度,将客户聚类,使得采样的客户端更具有代表性
优点:
1)增加了在全局模型中客户端的代表性,具有唯一分布的客户端,更有可能被采样,
2)并有可能导致更平滑、更快速的FL收敛,
3)两种方法都具备无偏性。
3)算法1——基于样本量大小的聚类
1) 基于样本数量的聚类抽样
参考[1]Blog 文中写的。
三、实验
1. 实验设置
参考[1]Blog 文中写的。
2. 代码
def get_clusters_with_alg1(n_sampled: int, weights: np.array):
"Algorithm 1"
epsilon = int(10 ** 6)
# associate each client to a cluster
augmented_weights = np.array([w * n_sampled * epsilon for w in weights])
ordered_client_idx = np.flip(np.argsort(augmented_weights))
n_clients = len(weights)
distri_clusters = np.zeros((n_sampled, n_clients)).astype(int)
k = 0
for client_idx in ordered_client_idx:
while augmented_weights[client_idx] > 0:
sum_proba_in_k = np.sum(distri_clusters[k])
u_i = min(epsilon - sum_proba_in_k, augmented_weights[client_idx])
#u_i = augmented_weights[client_idx]
distri_clusters[k, :client_idx] = u_i
sum_proba_k = np.sum(distri_clusters[k])
if sum_proba_k == client_idx * int(augmented_weights[client_idx]):
augmented_weights[client_idx] += -u_i
k += 1
distri_clusters = distri_clusters.astype(float)
for l in range(n_sampled):
distri_clusters[l] /= np.sum(distri_clusters[l])
return distri_clusters
参考官网代码出现两个问题:
1. k值始终无法增加,原因是初始的epsilon设置的太大了,不过也可能是我电脑问题,sum_proba_k=np.sum(distri_clusters[k])时,无法表示int64的数。
2. n_sampled给定的是10,但是client_idx循环次数远超过10次(100次)。
3. 修改后试验结果和论文结果对比:
1)代码结果:
2) 论文中结果:
训练次数比较少,不过大致看起来比较像,可能我改的也有问题,后面再继续看下无偏性和其他的。CIFAR10数据集的跑了α=0.001的,但是跑起来太慢了。
参考:
[1]Blog: 联邦学习——基于聚类抽样进行客户选择_联邦学习小白-CSDN博客https://blog.csdn.net/weixin_42534493/article/details/119330027