0 前言
机器学习两大挑战:
- 数据安全难以得到保障,隐私数据泄露问题亟待解决。
- 网络安全隔离和行业隐私,不同行业、部门之间存在数据壁垒,导致数据形成“孤岛”无法安全共享,而仅凭各部门独立数据训练的机器学习模型性能无法达到全局最优化。
为了解决以上问题
谷歌提出联邦学习(FL,federated learning)技术,其通过将机器学习的数据存储和模型训练阶段转移至本地用户,仅与中心服务器交互模型更新的方式有效保障了用户的隐私安全。
1 什么是联邦学习?
传统的机器学习算法需要用户将源数据上传到高算力的云服务器上集中训练,这种方式导致了数据流向的不可控和敏感数据泄露问题。
Mcmahan 等在 2016 年提出联邦学习技术,允许用户在机器学习过程中既可以保护用户隐私,又能够无须源数据聚合形成训练数据共享。
联邦学习本质上是一种分布式的机器学习技术,其流程如图 1 所示。
多个客户端设备(如平板电脑、手机、物联网设备)和中心服务器(如服务提供商)的协调下,共同训练一个模型。在这一过程中,客户端负责利用本地数据进行模型训练,得到本地模型(local model),而中心服务器则负责将各客户端训练得到的本地模型进行加权聚合,从而形成全局模型(global model)。通过多轮迭代,这一过程最终得到一个效果接近于传统集中式机器学习的模型 ( w w w ),从而有效降低了传统机器学习中由于数据集中带来的隐私风险。
联邦学习的一次迭代过程如下:
- 客户端从服务器下载上一轮的全局模型 w t − 1 w_{t-1} wt−1 。
- 客户端 k k k 利用本地数据训练得到本地模型 w t , k w_{t,k} wt,k (第 k k k 个客户端在第 t t t 轮通信中的本地模型更新)。
- 各客户端将本地模型更新上传至中心服务器。
- 中心服务器接收各客户端的数据后进行加权聚合操作,得到全局模型 w t w_t wt(第 t t t 轮通信中的全局模型更新)。
联邦学习技术具有以下几个特点:
- 参与联邦学习的原始数据保留在本地客户端,与中心服务器交互的只是模型更新信息。
- 联邦学习的参与方共同训练出的模型 w w w 将被各方共享。
- 联邦学习最终的模型精度与集中式机器学习相似。
- 联邦学习参与方的训练数据质量越高,全局模型的精度越高。
2 联邦学习的算法原理
联邦学习的目的是在保证数据隐私的前提下,通过多个客户端设备协作训练一个全局模型 w \mathbf{w} w 。在这种方法中,数据存储和处理都是在本地客户端设备上完成的,只有模型更新的梯度信息需要上传至中心服务器。
目标函数
中心服务器的目标是优化全局模型的目标函数 F ( w ) F(w) F(w) ,这个目标函数通常定义为所有客户端设备的加权平均:
min w F ( w ) , F ( w ) = ∑ k = 1 m n k n F k ( w ) (1) \min_{w} F(w), \quad F(w) = \sum_{k=1}^{m} \frac{n_k}{n} F_k(w) \tag{1} wminF(w),F(w)=k=1∑mnnkFk(w)(1)
其中:
- m m m 是参与训练的客户端设备总数。
- n n n 是所有客户端数据量总和,即 n = ∑ k = 1 m n k n = \sum_{k=1}^{m} n_k n=∑k=1mnk 。
- n k n_k nk 是第 k k k 个客户端的数据量。
- F k ( w ) F_k(w) Fk(w) 是第 k k k 个设备的本地目标函数。
本地目标函数
每个客户端 k