摘要
联邦学习是一种机器学习设置,其目标是训练高质量的集中模型,同时训练数据仍然分布在大量客户机上,每个客户机都具有不可靠且相对较慢的网络连接。我们考虑学习这种设置的算法,在每轮测试中,每个客户端都根据其本地数据独立计算对当前模型的更新,并将此更新与中央服务器通信,在中央服务器上聚合客户端更新以计算新的全局模型。在这种情况下,典型的客户是手机,通信效率是最重要的。
1介绍
随着数据集越来越大,模型越来越复杂,训练机器学习模型越来越需要在多台机器上分布模型参数的优化。现有的机器学习算法是为高度受控的环境(例如数据中心)设计的,在这种环境中,数据以平衡的、i.i.d.的方式分布在机器之间,并且具有高吞吐量的网络。
最近,联邦学习(以及相关的分散方法) (McMahan & Ramage)提出了另一种设置:一个共享的全球模型在来自参与设备联盟的中央服务器的协调下训练。参与其中的设备(客户机)通常数量很大,并且具有缓慢或不稳定的internet连接。当培训数据来自用户与移动应用程序的交互时,联邦学习的一个主要激励例子就出现了。
联合学习使移动电话能够协作学习共享的预测模型,同时将所有培训数据保存在设备上,从而将机器学习的能力与将数据存储在云中的需要分离开来。培训数据保存在用户的移动设备上,设备被用作对其本地数据执行计算的节点,以便更新全局数据模型。这超出了使用本地模型对移动设备进行预测的范围对设备进行模型训练。上述框架不同于传统的分布式机器学习,由于客户非常多,高度不平衡和非i.i.d。每个客户机上可用的数据,以及相对较差的网络连接。在本文中,我们将重点放在最后一个约束上,因为这些不可靠且不对称的连接对实际的联邦学习构成了特殊的挑战。
在本文中,我们提出了两种降低上行通信成本的方法:结构化更新,我们可以直接从有限的空间中学习更新,使用较少的变量,如低秩或随机掩码;还有草图更新,在那里我们学习完整的模型更新,然后使用量化、随机旋转和子采样的组合对其进行压缩,然后再将其发送到服务器。在卷积网络和递归网络上的实验表明,该方法可以将通信成本降低两个数量级。
为简单起见,我们考虑用于联合学习的同步算法,其中典型的轮回包括以下步骤:
1.选择现有客户端的子集,每个子集都下载当前模型。
2.子集中的每个客户端根据其本地数据计算更新的模型。
3.模型更新从选定的客户端发送到服务器。
4.服务器聚合这些模型(通常通过平均)来构建一个改进的全局模型。
上述框架的简单实现要求每个客户端在每轮中将完整模型(或完整模型更新)发送回服务器。 对于大型模型,由于多种因素,这一步骤很可能成为联合学习的瓶颈。 一个因素是Internet连接速度的不对称属性:上行链路通常比下行链路慢得多。 美国的平均宽带速度是下载55.0Mbps相对于上传18.9Mbps,一些互联网服务提供商更加不对称,例如Xfinity的下行速度为125Mbps而不是15Mbps。
此外,现有的模型压缩方案可以减少下载当前模型所需的带宽,并采用了加密协议,以确保在平均数百或数千个其他更新之前,无法检查单个客户端的更新,这进一步增加了需要上传的位数。
因此,研究可降低上行链路通信成本的方法非常重要。本文研究了两种通用方法:
•结构化更新,我们可以直接从受限空间学习更新,该更新可以使用较少数量的变量进行参数化。
•草绘的更新,我们在其中学习完整的模型更新,然后将其压缩后再发送到服务器。
可以结合使用第2节和第3节中详细介绍的这些方法,例如,首先学习结构化更新并将其草绘; 不过,我们在这项工作中没有对这种组合进行实验。
在下文中,我们正式描述问题。== 联邦学习的目标是从存储在大量客户端的数据中学习一个参数包含在一个实矩阵w∈R d1×d2中的模型。==该模型具有包含在实矩阵W∈R d 1×d 2中的参数。
我们首先提供联合学习的纯沟通版本。 在t≥0的回合中,服务器将当前模型Wt分配给n t个客户端的子集St。 这些客户端根据其本地数据独立更新模型。 让更新的局部模型为Wt1 Wt2…,这样客户端i的更新可以写为。
这些更新可能是在客户端上计算的单个梯度,但通常是更复杂的计算的结果,
例如,对客户端的本地数据集采取了多个随机梯度下降(SGD)步骤。 在任何情况下,每个选定的客户端都将更新发送回服务器,在服务器上通过聚合所有客户端更新计算全局更新:
服务器选择学习速率ηt。 为简单起见,我们选择ηt = 1。
==在第4节中,我们描述了神经网络的联合学习,==其中我们使用单独的2D矩阵W来表示每一层的参数。 我们假设W右乘,即d1和d2分别代表输出和输入尺寸。
注意,完全连接的层的参数自然地表示为2D矩阵。 但是,卷积层的内核是形状为#input×width×height×#output 的4D张量。 在这种情况下,W从内核重塑为形状(#input×width×height)×#output。
概述和总结。
提高联合学习的通信效率的目标是减少向服务器发送H的成本,同时学习存储在大量设备上的数据,这些设备的internet连接和计算可用性有限。我们提出了两类通用方法,即结构化更新和草图更新。在“实验”部分,我们评估了这些方法在训练深度神经网络中的效果。
在CIFAR数据的模拟实验中,我们研究了这些技术对联合平均算法的收敛性的影响。收敛速度略有下降的情况下,我们就能将通信的数据总量减少两个数量级。这使我们可以通过全卷积模型获得良好的预测精度,而总的通信量却少于原始CIFAR数据的大小。在针对用户分区的文本数据进行的较大的实际实验中,我们证明了我们甚至可以只使用每个用户的数据,就能够有效地训练递归神经网络进行下一个单词的预测。最后,我们注意到我们获得了最佳结果,包括使用结构化随机轮换对更新进行预处理。此步骤的实际效用在我们的设置中是独一无二的,因为在典型的SGD并行实现中,应用随机轮转的成本将占主导地位,但与联邦学习中的本地培训相比,这是微不足道的。
2 结构化更新
通信有效更新的第一类型将更新H限制为具有预定结构。 本文考虑了两种类型的结构:低秩和随机掩码。 需要强调的是,我们直接训练该结构的更新,而不是像第3节中讨论的那样,将一般的更新与特定结构的对象相匹配/绘制草图。
低秩
我们将对本地模型H的每次更新强制为最多k个秩的低秩矩阵,其中k是一个固定数。为此,我们将H表示为两个矩阵的乘积。在后续计算中,我们随机生成A并在局部训练过程中考虑一个常数,并且仅对B进行优化。注意,在实际的实现中,A可以在这种情况下情况以随机种子的形式进行压缩,只需要将经过训练的B发送到服务器即可。这种方法立即节省了通信中的d 1 / k因子。
我们在每个回合中为每个客户重新生成矩阵A。我们还尝试了固定B和培训A,以及同时培训B;两者都表现不佳。对这一现象的直观解释如下。我们可以把B解释为一个投影矩阵,把A解释为一个重构矩阵。修复A并对B进行优化,类似于问“给定随机重建,能够恢复最多信息的投影是什么?”在这种情况下,如果重构是满秩的,则存在恢复由前k个特征向量张成的空间的投影。然而,如果我们随机固定投影并搜索重建,我们可能会很不幸,重要的子空间可能已经被投影出去了,这意味着没有重建可以做得很好,或者很难找到。
随机掩码
遵循预定义的随机稀疏模式(即随机掩码),我们将更新H限制为稀疏矩阵。 模式在每轮中重新生成,并独立地为每个客户机生成。 类似于低秩,稀疏模式可以由随机种子完全指定,因此只需要将H的非零项的值连同种子一起发送。
3草图更新
第二种解决通信成本的更新(我们称之为草绘),首先在本地训练期间计算完整H,没有任何限制,然后以(有损)压缩形式对更新进行近似或编码,然后将其发送到服务器。 服务器在进行聚合之前对更新进行解码。为了进行草图绘制,我们尝试了多种工具,这些工具可以相互兼容并可以共同使用:
子采样
每个客户端只发送矩阵H,而不是发送H,每个客户端只有通信矩阵ˆH是由H一个随机的子集(比例)形成的。然后,服务器对子采样的更新取平均值,产生全局更新ˆH。 可以这样做,以使采样更新的平均值是真实平均值的无偏估计量:E [ˆH] = H。 与随机掩码结构化更新类似,掩码在每个回合中对每个客户端独立随机化,并且掩码本身可以存储为同步种子。
概率量化
压缩更新的另一种方法是量化权重。
我们首先描述将每个标量量化为一位的算法。
h = (h 1 ,…,h d 1 ×d 2 ) = vec(H ), hmax= maxj (hj ), hmin = minj (hj )
由^h表示的h的压缩更新如下生成:
很容易证明〜h是h的无偏估计量。 与4字节浮点数相比,此方法可提供32倍的压缩率。 分析了这种压缩方案产生的误差。对于每个标量,也可以将其推广到1位以上。 对于b位量化,我们首先将[hmin,hmax]平均分为2 ^ b个区间。 假设hi落在h’和h’‘的范围内。 通过分别用h’和h’'代替上述方程式的hmin和hmax来进行量化。 然后,参数b允许在精度和通信成本之间进行平衡的简单方法。Alistarh等人最近提出了另一种量化方法,该方法也通过减少通信而求平均向量。(2016)。 可以在量化更新设置中类似地分析增量,随机和分布式优化算法(Rabbat&Nowak,2005; Golovin等,2013; Gamal&Lai,2016)。
结构化的旋转随机量化
通过结构随机旋转改进量化。当尺度在不同维度上近似相等时,上面的1位和多位量化方法效果最好。
例如,当max = 1和min = -1且大多数值为0时,1位量化将导致较大的误差。我们注意到,在量化之前对h进行随机旋转(将h与随机正交矩阵相乘)可以解决此问题。在该工作中,表明结构化随机旋转可以将量化误差降低O(d / logd)倍,其中d是h的维数。我们将在下一部分中显示其实用性。
在解码阶段,服务器需要在聚合所有更新之前执行反向旋转。请注意,在实践中,h的维数很容易高达d = 10^6或更大,并且在计算上禁止生成(O d ^3)和应用(O d ^2)一般旋转矩阵。我们使用一种结构化旋转矩阵,该矩阵是Walsh-Hadamard矩阵和二进制对角线矩阵的乘积。这降低了生成和应用矩阵到O(d)和O(dlogd)的计算复杂度,这与联合学习中的本地训练相比可以忽略不计。
4实验
我们使用联邦学习进行了实验,以训练针对两个不同任务的深层神经网络。首先,我们使用卷积网络和人工分块数据集对CIFAR-10图像分类任务进行了实验,并详细研究了我们提出的算法的性质。其次,我们使用更现实的联邦学习场景——公共Reddit帖子数据,来训练一个循环网络来预测下一个单词。
Reddit数据集对于模拟联邦学习实验特别有用,因为它提供了自然的每个用户数据分区(根据文章作者)。这包括在实际执行中可能出现的许多特征。例如,许多用户拥有相对较少的数据点,大多数用户使用的单词围绕特定用户感兴趣的特定主题聚集。
在我们所有的实验中,我们采用联合平均算法,该算法大大减少了训练一个好的模型所需的通信次数。 但是,我们希望我们的技术在应用于同步分布式SGD时将显示出类似的通信成本降低。
对于联合平均,在每个回合中,我们随机地均匀选择多个客户端,每个客户端在其本地数据集上执行几个时期的SGD,学习率为η。
对于结构化更新,SGD被限制为仅在受限空间中进行更新,也就是说,仅B项用于低级别更新,而无掩码项用于随机掩码技术。 从这个更新的模型,我们计算每个层H的更新。 在所有情况下,我们都以各种学习率选择来运行实验,并报告最佳结果。
4.1 CIFAR-10数据集上的卷积模型
在本节中,我们使用CIFAR-10数据集来研究作为联合平均算法一部分的拟议方法的属性。
CIFAR-10数据集中有50000个训练示例,我们将其随机分为100个客户端,每个客户端包含500个训练示例。 我们使用的模型架构是全卷积模型,该模型来自被称为“模型C”的模型,总共有10 ^ 6个以上的参数。 尽管此模型不是最新的模型,但它足以满足我们的需求,因为我们的目标是评估我们的压缩方法,而不是在此任务上获得最佳的准确性。
该模型有9个卷积层,其中第一个和最后一个具有比其他卷积少得多的参数。 因此,在整个部分中,当我们尝试减小单个更新的大小时,我们仅压缩内部的7层,每个层具有相同的参数3。 对于所有方法,我们都用关键字“模式”表示。 对于低秩更新 “模式= 25%”是指将更新的秩设置为全层转换的秩的1/4,对于随机蒙版或草绘,这是指除25%以外的所有参数都为零出。
在第一个实验(总结于图1中)中,我们比较了第2节中介绍的两种:结构化更新、随机掩码更新。 主要信息是,随着我们减小更新的大小,随机掩码的性能明显优于低秩。 特别地,当以回合数来测量时,随机掩码的收敛速度似乎基本上不受影响。 因此,如右列所示,如果目标是仅使上传大小最小化(upload size),那么减小更新大小的版本无疑是赢家。
图1:使用CIFAR数据进行结构化的更新以缩小各种模式的尺寸。
图2中,我们比较了结构化更新和草绘更新的性能,没有进行任何量化。 由于在上文中,结构化随机掩码更新的执行效果更好,因此为清楚起见,我们将低秩更新忽略不计。 我们将其与草绘的更新的性能进行比较,在第3节中介绍了是否使用随机旋转对更新进行预处理以及是否进行预处理,以及两种不同的模式。 我们用“ HD”表示随机Hadamard旋转,而用“ I”表示无旋转。
图2:对CIFAR数据进行结构化随机掩码更新和不进行量化的草图更新的比较
图3:比较更新的草图,将更新与CIFAR数据的旋转,量化和二次采样相结合。
直观的期望是,直接学习结构化随机掩码更新要比学习非结构化更新好,后者被素描为用相同数量的参数表示。 这是因为通过草绘,我们丢弃了训练中获得的一些信息。 通过草绘更新,我们应该收敛到稍低的精度这一事实在理论上得到了支持,因为草绘更新会增加收敛性分析中直接出现的方差。 我们在使用结构化随机掩码更新时看到了这种行为,最终我们可以收敛到稍微更高的精度。 但是,我们还看到,通过草绘更新,我们能够稍快地获得适度的准确性(例如85%)。
在对CIFAR数据的最后一次实验中,我们重点研究了第3节中介绍的所有三个元素的相互作用——子采样、量化和随机旋转。请注意,所有这些工具的组合将实现更高的压缩率。图3中的每对图都关注于特定的模式(子采样),并且在每对图中,我们用量化中使用的不同比特来绘制性能图(有无随机旋转)。我们在所有的图中都可以看到,随机旋转提高了性能。一般来说,没有旋转的情况下,算法的性能不太稳定,特别是在量化比特数较少和模式较小的情况下。
为了强调节省通信的潜力,请注意,通过随机旋转进行预处理,绘制出除了6.25%的更新元素之外的所有元素并使用2位进行量化,我们在收敛性上只有很小的下降,而节省了256倍 就表示各个层的更新所需的位而言。 最后,如果我们希望最大程度地减少上传的数据量,则可以获得一定程度的准确性,例如85%,而总的交流量不到上传原始数据所需的一半。
4.2 LSTM对REDDIT数据的下一个单词预测
我们构建了用于模拟联邦学习的数据集,该数据基于包含Reddit上公开可用的帖子/评论的数据(谷歌BigQuery)。就我们的目的而言,数据库中的每个帖子都由一个作者键控,因此我们可以根据这些键对数据进行分组,假设每个作者有一个客户机设备。有些作者有非常多的帖子,但是在每轮FedAvg(联邦平均)中,每个用户最多处理32000个令牌。我们省略了少于1600个令牌的作者,因为在模拟中每个客户机都有固定的开销,而且数据很少的用户对培训贡献不大。这就剩下763430个用户的数据集,每个用户平均拥有24791个令牌。为了进行评估,我们使用一个相对较小的测试集,该测试集包含75122个标记,这些标记由随机保留的帖子组成。
基于此数据,我们建立了LSTM下一词预测模型该模型被训练用于预测给定当前单词的下一个单词,以及从前一个时间步传递的状态向量。 该模型工作如下:通过在10017个单词(令牌)的字典中查找单词,将单词st映射到一个嵌入向量et(∈R^ 96)然后,et与模型在前一个时间步st1(∈R^256)中发出的状态组成一个新的状态向量st和一个“输出嵌入”Ot(∈R ^96)
在通过softmax归一化计算整个词汇表上的概率分布之前,通过内积对输出词表的嵌入项进行评分。与其他标准语言模型一样,我们将每个输入序列视为以隐式“BOS”(序列开始)标记开始((令牌开头),以隐式“EOS”(序列结束)标记结束(令牌结尾)。与标准的LSTM语言模型不同,我们的模型对嵌入层和softmax层使用相同的学习嵌入。这使得模型的大小减少了约40%,而模型的质量却略有下降, 这对于移动应用程序来说是一个有利的折衷。与许多标准LSTM RNN方法不同的另一个变化是,我们 训练了这些模型以将单词嵌入限制为具有固定的L2规范1.0,发现可以缩短收敛时间。该模型总共有1.35M个参数。
为了减小更新的大小,我们绘制了所有模型变量的草图,但一些小变量(例如偏差)消耗的内存少于0.01%。 我们使用AccuracyTop1进行评估,即模型赋予最高概率的单词是正确的。即使模型中预测为“未知”,即使词典中没有下一个真正的单词,我们也总是将其视为错误。
在图4中,我们对Reddit数据运行联邦平均算法,并使用各种参数指定草绘。在每个迭代中,我们随机抽样50个用户,这些用户根据本地可用的数据计算更新,并绘制草图,然后对所有更新进行平均。在每个轮中对10、20和100个客户进行抽样的实验提供了类似的结论,如下所示。
在所有的图中,我们将这三个组件组合起来,以绘制第3节中介绍的更新。首先,我们应用随机旋转对本地更新进行预处理。此外,“草图分数”设置为0.1或1,表示更新的元素被子采样的分数。
图4:草图更新的比较,在Reddit数据上训练一个循环的模型,每轮随机抽取50个客户。
在左列中,我们根据算法的迭代次数来绘制这个图。 首先,我们可以看到,随机旋转的预处理效果具有显着的正效应,尤其是在量化位数较少的情况下。 有趣的是,对于所有选择的次采样率,将量化为2位的随机Hadamard变换不会造成性能损失。 要强调的一个重要指标是图中显示的回合数为2000。由于我们每回合对50个用户进行采样,所以这个实验不会接触到大多数用户的数据。这进一步加强了在实际环境中应用联邦学习而不影响用户体验的说法。
在右列中,我们将相同的数据与客户机需要与服务器通信的兆字节总数进行比较。从这些图中可以明显看出,如果需要一个主要最小化该指标的工具(如果主要需要最小化这个度量),那么我们提出的技术将非常有效。 当然,这些目标都不是我们在实际应用中要优化的目标。 尽管如此,鉴于当前缺乏大规模部署联合学习固有的问题的经验,我们认为这些是在实际应用中有用的代理。
图5:每轮培训中使用的客户数量的影响。
最后,在图5中,我们研究了单轮使用的客户数量对收敛性的影响。 我们对固定数量的回合(500和2500)运行联邦平均算法,每回合使用不同数量的客户端,将更新量化为1位,并绘制得出的精度。 我们看到,每轮有足够数量的客户(在本例中是1024个),我们可以将子采样元素的比例降低到1%,与10%相比,精度只会有微小的下降。在联邦设置中,这是一个重要且实用的权衡:可以在每轮中选择更多的客户机,同时使每个客户端之间的通信更少(例如,更主动的子采样),并获得与使用较少客户端相同的精度,但是每个客户端它们之间通信更多。 当有许多客户端可用时,前者可能更可取,但是每个客户端的上传带宽非常有限-这在实践中很常见。