目录
联邦学习概念
解决数据孤岛问题,保护隐私,本地数据,训练全局模型
常见算法
聚合算法:
FedAvg (广泛使用)
FedProx (改进的FedAvg,引入了正则化项来平衡本地模型和全局模型之间的权重,更好地处理数据不平衡和设备异构性问题)
FedOpt (针对联邦学习中的优化问题提出,采用递归的方式进行模型训练,引入动量和自适应学习率)
攻击算法:
数据污染攻击 本地数据集篡改或注入恶意样本(本地影响全局)
模型替换攻击 本地模型参数替换或修改(本地影响全局)
梯度泄露攻击 利用返回的梯度信息推断其他参与者数据特征或标签信息(反向推断)
防御算法:
梯度裁剪 限制每个参与者提交给服务器的梯度范数
Krum 多个预测结果中,选最具代表性作为最终结果,集成模型的误差。
Bulyan 基于Krum改进,增加了Trimmed Mean。
FedAvg算法详解
FedAvg
概念:加权平均聚合模型参数。将本地模型的参数上传到服务器,并对参数赋予权重(权重赋值标准是本地数据量大小),服务器计算所有模型参数的加权均值,然后将这个平均值广播回本地设备。迭代多次,直到收敛。
问题:数据不平衡问题:每个设备上传的模型参数的权重是根据设备上的本地数据量大小进行赋值的。可能会导致数据不平衡的问题,即数据量较小的设备对全局模型的贡献较小,影响泛化性能。
- 服务器初始化全局模型参数 $w_0$;
- 所有本地设备随机选择一部分数据集,并在本地计算本地模型参数 $w_i$;
- 所有本地设备上传本地模型参数 $w_i$ 到服务器;
- 服务器计算所有本地模型参数的加权平均值 $\bar{w}$,并广播到所有本地设备;
- 所有本地设备采用 $\bar{w}$ 作为本地模型参数的初始值,重复步骤2~4,直到全局模型收敛。
def FedAvg(self):
total_samples = sum(self.num_samples)'''变量'''
base = [0] * len(self.weights[0])'''创建一个列表'''
for i, client_weight in enumerate(self.weights):
for j, v in enumerate(client_weight):
base[j] += (self.num_samples[i] / total_samples * v.astype(np.float64))
return base
列表:
base / self.num_samples / self.weights[i](第i个客户端的权重列表,每一个元素都是一个列表或数组)([[3],[4],[6],[7],[2],...])(所以只有通过for遍历两次才能得到一个单纯的数字)
初始化:
total_samples
:所有客户端的总样本数。self.num_samples
是一个列表,包含每个客户端的样本数(该列表长度为参与训练的客户端个数)(归一化因子确保每个客户端的权重更新根据其样本数量进行加权)。
sum(self.num_samples)
:sum
函数计算了self.num_samples
列表中所有元素的总和。列表求和
base
:列表,用于存储聚合后的全局模型权重。其长度与单个客户端的权重列表相同,初始化为全零列表。
[0] * len(self.weights[0])
:创建了一个新列表,长度与self.weights[0]
相同,所有元素都被初始化为0。len(self.weights[0])
:这个计算了第一个客户端权重列表(或数组)的长度。由于联邦学习假设所有客户端的模型结构相同,因此这个长度也代表了所有客户端权重列表(或数组)的长度。
遍历客户端权重:
enumerate(self.weights):
遍历每个客户端的权重列表(client_weight
),同时获取其索引(i)
。
加权聚合:
对每个客户端的权重列表(
client_weight
),遍历其每个权重值(v
)。使用每一个客户端的样本数(self.num_samples[i]
)作为权重因子,对全局权重(base
)进行加权更新。每个客户端的权重值v
被转换为np.float64
类型(确保精度),(.astype(np.float64)将数组或序列中的数据类型转换为 64 位浮点数)
base[j] += (self.num_samples[i] / total_samples * v.astype(np.float64)):第i个客户端中第j个客户端的权重=第i个客户端的样本数除以所有客户端的总样本数*第j个客户端权重值,因为是+=,所以最终效果是base中第i个元素是第i个客户端的权重信息
返回聚合后的模型权重:函数返回base
,即聚合了所有客户端权重信息后的全局模型权重。
- 这段代码假设所有客户端的模型结构相同,即权重列表的长度一致。