联邦元学习综述

源自:大数据

作者:张传尧  司世景  王健宗  肖京

“人工智能技术与咨询”  发布

摘 要

随着移动设备的普及,海量的数据在不断产生。数据隐私政策不断细化,数据的流动和使用受到严格监管。联邦学习可以打破数据壁垒,联合利用不同客户端数据进行建模。由于用户使用习惯不同,不同客户端数据之间存在很大差异。如何解决数据不平衡带来的统计挑战,是联邦学习研究的一个重要课题。利用元学习的快速学习能力,为不同数据节点训练不同的个性化模型来解决联邦学习中的数据不平衡问题成为一种重要方式。从联邦学习背景出发,系统介绍了联邦学习的问题定义、分类方式及联邦学习面临的主要问题。主要问题包括:隐私保护、数据异构、通信受限。从联邦元学习的背景出发,系统介绍了联邦元学习在解决联邦学习数据异构、通信受限问题及提高恶意攻击下鲁棒性方面的研究工作,对联邦元学习的工作进行了总结展望。

关键词

联邦学习 ; 元学习 ; 数据异构 ; 联邦元学习 ; 隐私保护

 引 言

随着移动设备的普及,海量的数据在不断产生,合理有效地利用这些数据成为重点研究方向。由于隐私政策的保护,很多数据不能被轻易地获取,数据间相互隔离,形成了一个个数据“孤岛”。如何建立数据“孤岛”间沟通的桥梁,打破数据之间的界限,成为一个热点问题。联邦学习为解决该问题提供了一个新的方向。

联邦学习在满足数据隐私要求、保护数据安全、遵守政府法规的前提下,进行数据的使用和建模,即通过只在各节点间传递模型参数,而不分享节点间数据的方式训练一个共享的数据模型[1]。许多早期的研究旨在在数据不公开的情况下分析和利用分布在不同所有者手中的数据。早在20世纪80年代,对加密数据进行计算的研究就已经展开,直到2016年,谷歌研究院[2]正式提出联邦学习这一术语,对分布式数据的隐私保护研究才开始归于一类。联邦学习成为解决数据隐私保护问题的一个有力工具。

在传统的机器学习中,通常需要大量的数据样本进行训练,才能获得一个较好的模型。例如在神经网络中,需要大量的标签数据进行模型训练,才能使模型具有良好的分类效果,并且一个训练好的神经网络模型往往只能解决某一类问题。在某些情况下,数据本身是稀缺的,大量的有标签数据是不容易获得的,往往只有少量的样本能够进行数据训练。人类可以通过少量的某一类动物的图片学习到这种动物的概念,再见到这种动物时能够很快地识别出来。这种通过少量样本图片快速学习到新概念的能力,对应机器学习中元学习的概念。元学习的训练目标是训练一个模型,这个模型只需要通过少量的数据和迭代训练就可以快速适应新的任务,即训练一个具有很强适应能力的模型[3]。元学习能够很好地解决训练数据不足的问题。元学习算法由两个部分构成:基础学习者和元学习者[2]。基础学习者在单个任务的水平上工作,其特征在于只有一小组标记的训练图像可用。元学习者从几个这样的情节中学习,目的是提高基础学习者在不同情节中的表现。一般认为元学习系统应当具有以下3个特征:拥有一个基础学习子系统;具有能够利用先前的经验获取知识的能力;能够动态地选择学习偏差。

元学习的早期研究工作主要集中在教育科学相关的领域,主要研究并控制自身的学习状态。随着机器学习的发展,元学习开始进入机器学习领域。元学习的第一个例子出现在20世纪80年代[2],参考文献[4]提出了一个描述何时可以动态调整学习算法归纳偏差,从而隐式地改变其假设空间元素顺序的框架。参考文献[5]提出具有两个“嵌套学习层”的元学习方法。元学习可以跨越多个问题进行经验的积累,以适应基础假设空间[3]。

考虑联邦学习在解决异构数据训练方面的需求和元学习在多任务模型上的良好表现,利用元学习训练一个个性化的联邦学习算法成为一种选择。现有的联邦学习[6]主要是利用不同的数据节点联合训练一个统一的全局模型,这种统一的全局模型不利于解决数据的非独立同分布问题。联邦元学习为不同的数据节点训练单独的数据模型,这种多模型的训练方式可以直接捕捉客户端间的数据不平衡关系,使它们很适合解决联邦学习的数据不平衡问题。

1 联邦学习简介

1.1  问题定义

联邦学习在满足数据隐私要求、保护数据安全、遵守政府法规的前提下,进行数据的使用和建模,即通过只在各节点间传递模型参数,而不分享节点间数据的方式训练一个共享的数据模型[1]。联邦学习不需要交换各数据节点间的数据,各节点间仅交换共享数据模型的参数,以保护用户的隐私安全。

定义n个数据拥有者{f1,f2,…,fn},不同数据拥有者 fi的本地目标用Fi(ω)表示,它们各自拥有自己的数据{D1,D2,…,Dn},并希望利用这些数据训练机器学习模型。传统的机器学习方法是利用数据D=D1∪D2∪…∪Dn 训练一个机器学习模型ωsum。在联邦学习中,服务器端使用聚合函数G(·)聚合来自不同数据拥有者的模型参数。数据拥有者在保护自身数据安全、互相不交换本地数据的情况下共同训练一个模型ωfed。联邦学习的全局目标定义如式(1)所示:

图片

  (1)

模型ωfed的精度vfed应当非常接近模型ωsum的精度vsum。如果存在非负实数δ使得式(2)成立:

图片

(2)

则称联邦学习算法具有δ精度损失。

1.2   联邦学习的训练过程

随着联邦学习研究的开展,各种各样的联邦学习框架被开发出来。例如微众银行的FATE已经覆盖了3种联邦学习:横向联邦学习、纵向联邦学习、联邦迁移学习[7]。谷歌开源的Tensor/IO已经可以较好地支持横向联邦学习。尽管不同的算法框架(例如PySyft、FFL-ERL、CrypTen、LEAF、TFF)[8]对联邦学习的支持不同,但是联邦学习的主要训练过程均可以分为以下4步。①中心服务器将最新的模型分发给各数据节点;②各数据节点利用本地数据更新模型;③各训练节点将更新的模型参数加密传送给中心服务器,中心服务器聚合各节点的参数,得到新的模型参数;④中心服务器将更新后的模型参数发送给各节点,节点更新本地模型参数,并进行下一轮训 练。联邦学习训练过程如图1所示。

1.3    联邦学习特点

联邦学习与传统机器学习存在很大不同,具体见表1。联邦学习的分布式环境设置导致不同数据节点的地理位置可能不同,用户的使用习惯存在差异,从而影响数据的分布。不同数据节点间是非独立同分布的,任何一个数据节点都不能代表整个数据集的分布。设备环境是否稳定也是影响联邦学习的一个重要因素,有限的网络通信速率要求找到一种合适的方式提高设备间的通信效率,同时还要避免因环境不稳定导致的设备随机加入与退出。隐私保护是联邦学习最基本的属性要求,当中间结果与数据结构一起暴露时,可能造成数据的泄露。因此如何解决数据非独立同分布问题,提高通信效率,如何进行隐私保护成为联邦学习的关键。

1.3.1    数据隐私保护

隐私性是联邦学习的基本属性,如果不能做到对数据的隐私进行有效保护,联邦学习将失去可靠性,不同的数据“孤岛”也不会将自己的数据贡献出来用于数据训练[9]。联邦学习在参数更新过程中,交换了工作的中间结果,因此不同数据方更容易受到推理攻击,敌对的参与方可以推断出训练数据子集的相关属性[7]。在数据交换时,隐私保护的方式有很多种,例如在机器学习期间通过加密机制下的参数交换来保护用户数据隐私[7],或者使用差分隐私的方式保护数据[10,11,12,13]。安全多方计算、安全聚合[14]也是常用的隐私保护手段。其中,使用差分隐私方式保护数据隐私的方法通过向数据加入噪声的方式掩盖真实的数据,但是加入的噪声可能会影响最终结果的准确度。如何确定加入的噪声量是一个值得研究的问题,加入的噪声太多会导致计算结果失去准确性,加入的噪声不足则导致隐私保护效果不好。

1.3.2    数据非独立同分布

身份、性格、环境的差异导致由用户产生的数据集可能存在很大的差异,训练样本并不是均匀随机地分布在不同的数据节点间的[15,16,17]。不平衡的数据分布可能导致模型在不同设备上的表现出现较大偏差。因此在进行联邦学习前,如何选取有效的数据集进行数据处理是一个重要的问题。要解决联邦学习中的数据非独立同分布问题,主要的思路有两种,一种是通过优化模型聚合的方法降低数据不平衡带来的影响,另一种是通过优化本地模型的更新过程解决联邦学习的统计挑战问题。参考文献[18]提出了一种基于迭代模型平均的深层网络联合学习方法,该方法对于不平衡和非独立同分布是稳健的。参考文献[15]提出通过每个设备上的类别分布和人口分布之间的地球移动者距离来量化数据集间的差异,并创建一个在所有边缘设备之间全局共享的数据子集来改进对非独立同分布数据的训练。

图片

图1   联邦学习训练过程

图片

表1   联邦学习与传统机器学习比较

1.3.3    通信环境受限

在联邦学习中,中心服务器与计算节点间的物理距离很远,通信成本较高[14],且由于计算节点环境的不稳定性,可能随时存在计算节点加入和退出的情况,因此联邦学习一般应选取网络环境稳定免费且计算节点空闲时进行。通信成本成为制约联邦训练的主要因素,因此如何对设备间的通信进行压缩是一个值得研究的问题,可以通过减小客户端传送到服务器的对象的大小、减小从服务器向客户端广播的模型大小、客户端从全局模型开始培训本地模型等方法降低对通信链路的要求[19]。参考文献[20]中给出两种降低上行链路通信成本的方法:结构化更新和草图更新。结构化更新直接从使用较少数量的变量(如低秩或随机掩码)的受限空间中学习更新;草图更新先模型更新,然后在发送到服务器之前,使用量化、随机旋转和二次采样的组合对其进行压缩。

1.4    联邦学习算法

联邦学习的更新过程主要分为服务器端更新和客户端更新两部分,按照算法对联邦学习改进的阶段,可以将联邦学习算法分为两类:基于服务器端聚合方法优化的算法和基于客户端优化的算法。

 1.4.1    基于服务器端聚合方法优化的算法

联邦学习算法通过聚合不同客户端参数共同训练一个全局模型,由于不同客户端数据是非独立同分布的,更新的模型参数可能存在很大不同,同时由于隐私保护的要求,服务器不能直接访问客户端数据,容易受到恶意攻击的影响,例如将使用错误标签更新的模型参数发送给服务器以误导模型更新方向。如何聚合来自不同客户端的数据以降低恶意攻击带来的影响,并提供一个针对不同客户端表现良好的全局模型是一个重要问题。联邦平均算法(federated averaging algorithm, FedAvg)[1]使用一种简单直白的权重聚合方法,将客户端内数据量与全体客户端总数据量的比值作为权重聚合不同客户端发送的参数。相较于联邦随机梯度下降算法(federated stochastic gradient descent algorithm,FedSGD)[1]每次使用客户端所有数据进行一轮梯度下降的方式,其采用的本地多轮更新的方式加快了模型收敛速度。联邦平均算法因为其简单有效的思想很快流行起来,但是其简单的加权聚合方式难以解决数据异构、易受攻击的问题。基于服务器动量的联邦平均算法(federated averaging with server momentum algorithm, FedAvgM)[21]通过引入动量的方法缓解数据异构对联邦平均算法的影响。参考文献[22]借鉴非联邦环境下的自适应优化器(自适应梯度(adaptive gradient, AdaGrad)[23]、自适应矩估计(a d a p t i ve moment estimation,Adam)[24]和YOGI[25])提出了联邦学习版本的自适应优化器(联邦自适应梯度(federatedadaptive gradient,FEDADAGRAD)算法[22]、联邦化YOGI[22]和联邦自适应矩估计(federated adaptive moment estimation, FEDADAM[22])算法),通过自适应优化器显著提高了模型在数据异构情况下的收敛速度。简单的聚合难以应对不同客户端的个性化需求,参考文献[26]和[27]提出两种分层聚合的方式,其中参考文献[27]提出了共享基础层加个性化层的个性化联邦学习模型算法(personalized federated training algorithm,FedPER),其中基础层由不同客户端共同训练,个性化层由本地数据训练,这种带有个性化层的方法可以有效减少数据异构带来的模型在不同客户端上表现差异的问题。参考文献[26]提出了联邦匹配平均算法(federated matched averaging algorithm,FedMA),该算法以分层方式构建共享全局模型。

 1.4.2      基于客户端优化的算法

由于不同客户端上的数据是非独立同分布的,且不同节点间互相不能交换数据,客户端在本地数据进行模型训练时,无法得知其他客户端的信息,模型的更新方向可能会受到本地数据分布的影响,导致各个客户端模型更新方向出现较大差异。利用全局模型的信息约束本地模型的更新,可以在增加模型个性化的同时避免模型间出现较大偏差。参考文献[28]进一步扩展联邦平均算法,提出了一种新的算法FedProx,它规定客户端有局部损失函数,进一步使用基于前一步权重的二次惩罚进行正则化,在数据异构的环境中显示出联邦平均算法的进步性,该方法受到连续迁移学习早期工作的影响,还具有很大的改进空间。参考文献[15]提出基于地球移动距离的联邦算法(federated earth mover’ distance,Fed-EMD),该算法将部分客户端生成的参数或服务器生成的模型共享给整个客户端,并通过创建一个在所有数据节点间共享的数据子集,减少数据不平衡的影响。但是这些解决方案需要很高的通信成本,且难以满足联邦学习的隐私保护要求。以往的目标是通过网络训练一个统一的全局模型[29],这种方法难以解决联邦学习中的统计问题。参考文献[30]提出了一种自适应个性化联邦学习算法(adaptive personalized federated learning algorithm,APFL),其通过推导全局模型和局部模型的一般边界找出最优混合参数,并提出了一种高效的通信方法,帮助客户端高效地学习个性化模型。联邦平均算法虽然简单、通信成本低,但是其受到数据不平衡的影响很大,且不同客户端在进行本地更新时无法了解到其他客户端的更新信息,可能会由于本地数据的异构性导致其更新方向与其他客户端产生漂移,利用正则化项可以很好地约束本地模型的更新方向。参考文献[31]提出了一种随机控制平均算法(stochastic controlled averaging algorithm,SCAFFOLD),其使用控制变量(方差减少)来纠正本地更新中的“客户端漂移问题”,还可以利用客户间的相似性,进一步降低所需的沟通成本。同样的,在参考文献[32]中,Ditto算法在不同客户端损失函数中引入正则化项,并通过正则化项前系数控制模型在个性化和鲁棒性间的平衡。参考文献[33]借鉴对比学习的思想提出模型对比联邦学习(model contrastive federated learning, MOON)算法,其不同于Ditto算法将本地模型与全局模型的欧氏距离作为正则化项以鼓励个性化模型向全局最优模型靠近,MOON算法利用模型表示之间的相似性纠正各个客户端的局部学习,将全局模型与本地模型表示间的对比损失作为正则化项约束本地模型的更新。

联邦学习算法分类见表2。

2 元学习介绍

2.1  元学习定义

很难给出元学习的确切形式化定义[34]。一般来说元学习就是学会去学习,希望训练一个通用的学习算法,该算法可以很好地适应新的任务,元学习研究系统如何通过经验提高效率,目标是了解学习本身如何根据学习领域灵活变动[35-36],元学习往往以小样本学习和对任务的快速适应作为切入点[37]。

图片

表2   联邦学习算法分类

人类可以通过几张动物的照片快速地学习到该动物的概念,这对应元学习的少镜头学习(few-shot learning,FSL)情景。人类甚至可以在没有图像的情况下,仅仅凭借描述就能认识到新的类别,这对应元学习的零镜头学习情景。元学习按照支持集每类样本的数量可以分为3类:单镜头学习(one-shot)、k镜头学习(k-shot)、零镜头学习(zero-shot)。

元学习一般训练过程如图2所示。首先在训练集上采样构建不同的任务

图片

 (由支持集Si和查询集Qi组成),模型在支持集Si上进行参数优化,得到对应该任务

图片

 的中间参数模型ϕi,然后在查询集Qi上使用模型ϕi 计算损失函数Li,并最小化不同任务上损失函数值的和训练一个基础模型ϕ。然后在测试集中,通过任务 

图片

 的支持集数据进行简单的几步梯度下降就可以得到新的模型ϕ′,以适应新的任务。最后在测试集中,利用任务

图片

中查询集测试模型ϕ′的表现。培养机器利用先前经验快速适应新任务的能力,就是让机器学会学习。

2.2   元学习分类

传统的机器学习目的在于让机器学会理解事物的异同以区分不同的事物,而不是学会识别没见过的事物。元学习的目的是教会机器如何利用先验知识快速学习新知识,快速掌握识别新物体的能力。根据训练数据有无标签,可以将元学习分为监督元学习和无监督元学习两种,如图3所示。

 2.2.1  无监督元学习方法

当训练数据没有标签时,无监督元学习[38-39]常采用一种显式的方式自动构建数据集,通过构建虚标签的方式,将无监督学习转换为监督学习。参考文献[38]为了解决训练数据无标签问题,提出了一个分阶段训练集群自动构造任务(clustering to automatically construct tasks, CACTUs)算法,其先在无标签训练数据上使用无监督训练方法学习一个特征表示器,然后通过聚类的方式在无标签数据上进行聚类划分,并生成伪标签。其通过伪标签构建元学习任务,在元学习任务上训练常规的监督元学习模型,如模型不可知元学习算法(model-agnostic meta-learning,MAML)[5]、原型网络算法(prototypical networks,ProtoNet)[40]等。与CACTUs算法分阶段训练的过程不同,参考文献[39]提出一种端到端的无监督元学习(unsupervised meta-learning with tasks constructed by random sampling and augmentation,UMTRA)算法,其通过数据增强的方式为每个图片生成一个增强数据,并将原数据作为支撑集,将增强图片作为查询集,构建N类单镜头(N-way-1-shot)任务。UMTRA在无标签数据上的分类准确度已经非常接近有标签数据集上MAML模型的分类准确度。

图片

图2   元学习一般训练过程

图片

图3   元学习分类

2.2.2     监督元学习方法

监督元学习旨在根据有限的数据信息,快速学习适应新任务的能力,根据算法分类,可以将监督元学习的方法分为以下5种:基于优化器的方法、基于记忆存储的方法、基于基础泛化模型的方法、基于度量学习的方法、基于数据增强的方法。

(1)基于优化器的方法

基于优化器的方法旨在通过学习一个更好的优化器加快学习过程。参考文献[41]中提出使用一个基于长短期记忆网络(long short term memory,LSTM)的元学习者来学习一个更新规则,这个更新规则可以被看成一种新的类似于但不同于梯度下降的优化算法。在参考文献[42]中RL2 (fast reinforcement learning via slow reinforcement learning)利用一个慢速调整的强化学习方法去训练一个快速调整的强化学习,达到学会学习的目标,即用一个强化学习去学习一个强化学习算法。参考文献[43]中提出了一种加快神经网络训练速度的方法LTL(learning t o learn by gradient descent by gradient descent)以实现快速学习,通过以往的神经网络学习的任务预测梯度,文章通过为梯度下降算法训练一个学习器以加快梯度下降算法的收敛速度。

(2)基于记忆存储的方法

先验知识对于后续任务具有重要作用,合理利用先验知识可以帮助模型快速适应新的任务。参考文献[44]提出了一种带有记忆增强神经网络的元学习(memoryaugmented neural network,MANN)算法,该算法利用外部存储器进行样本特征的保存,使用元学习算法改进单元的读取和写入方式,并采用错位匹配的方式避免在训练过程中记住样本的相应位置。权重的缓慢更新实现了网络的长期记忆功能,并利用外部存储实现短期记忆,最终实现元学习的快速训练。在参考文献[45]中,作者引入时间卷积网络访问之前的特征信息提出了一种元学习模型简单神经注意力学习器(simple neural attentive learning, SNAIL),使其可以在某个固定的时间内使用更加灵活的计算。通过时间卷积网络和注意力机制的结合,网络可以更加准确地在先前的信息中进行选择。

(3)基于基础泛化模型的方法

基于基础泛化模型的方法旨在学习一个可以快

图片

图片

图片

图片

图片

图4   MAML算法训练过程

图片

图5   原型网络示意图

监督元学习分类见表3。

3 联邦元学习介绍

3.1   联邦元学习定义

图片

3.2    联邦元学习算法分类

图片

图片

图片

表3   监督元学习分类

图片

 3.2.1     面向数据异构的联邦元学习算法

图片

图片

图片

图片

图6   联邦元学习算法分类

(3)Per-FedAvg算法

图片

(4)所示:

图片

(4)

图片

图片

图片

图片

图片

3.2.2    面向资源挑战的联邦元学习算法

图片

图片

图片

图片

3.2.3    面向隐私保护的联邦元学习算法

图片

图片

图片

图片

图片

3.3    联邦元学习应用

图片

图片

图片

表4   联邦元学习算法分类

图片

图片

图片

4 总结和展望

图片

图片

图片

图片

声明:公众号转载的文章及图片出于非商业性的教育和科研目的供大家参考和探讨,并不意味着支持其观点或证实其内容的真实性。版权归原作者所有,如转载稿涉及版权等问题,请立即联系我们删除。

“人工智能技术与咨询”  发布

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值