Communication-Efficient Learning of Deep Networks from Decentralized Data
原文来源:[Arxiv2017] Communication-Efficient Learning of Deep Networks from Decentralized Data
Abstract
Federated Learning : 训练数据分布在移动设备上,通过聚合本地计算的更新来学习共享模型
提出的Iterative model averaging
Introduction
Federated Learning 的 Learning tasks是被 loose federation 的参与者,即clients完成,由central server来提供协调
每个client都有一个local dataset,这个dataset是不会上传到server的
每个client对global model计算一个update,对这个update进行communication。这是focused collection集中收集或者data minimization数据最小化原则的直接应用
这些updates是针对特定model,当被应用后,就没有必要存储他们
优点:模型训练与直接访问原始数据的解耦。也可将供给面限制在设备,而不是设备和云,来降低风险
提出了Federated Averaging的算法,结合了本地的SGD,在每个Client和Server制行model averaging
Federated Learning的属性:
1)与数据中心提供的针对代理数据的培训相比,来自mobile devices的real-world的数据训练更有优势
2)数据都是privacy sensitive或者large in size,因此不要将其记录在data center
3)对于监督学习任务,数据的labels可以从与users的交互中推断出
Privacy
data center 拥有一个匿名的数据集,通过连接其他用户的数据危及隐私
而federated learning传输数据是一些model的最小的updates,隐私的强度也取决于更新的内容
aggregation algorithm可以在不识别元数据来源的情况下完成,因此updates可以在直接传输
Federated Optimization
clients的数据都是基于mobile devices的使用情况,特定用户的local dataset不会代表什么分布
一些users对于service或者app的使用会更多,导致local training data的数量是变化较大
参与optimization的用户数量远大于每个用户的示例数量
mobile devices通常是offline
Federated optimization需要考虑实际的问题:client的dataset随着数据的增添删除会变化
过程
固定K个clients,每个都有一个固定的local dataset
每一轮开始时,随机选中C个clients,server发送当前的global algorithm state给每个clients(当前global model的parameters)
每个选中的clients基于global state和local dataset进行本地计算,发送update到server
server把这些updates应用到global state,并重复上述过程
f
i
(
w
)
是在样例(
x
i
,
y
i
)上的基于
g
l
o
b
a
l
p
a
r
a
m
e
t
e
r
s
w
的预测的
l
o
s
s
值
f_i(w)是在样例(x_i,y_i)上的基于global \quad parameters\quad w的预测的loss值
fi(w)是在样例(xi,yi)上的基于globalparametersw的预测的loss值
使用额外的计算来减小通信的轮数
1)increased parallelism 增加更多的clients独立工作
2)increased computation on each client
Related work
通过分布式的iteratively averaging locallly training已经有进展
联邦学习通常不考虑datasets的unbalanced和non-IID
也可关注训练的深度网络,强调privacy的重要性,在每一轮通信中共享参数的子集
每个model在本地找到minimize loss的parameters,然后发送到server去average
The FederatedAveraging Algorithm
SGD应用于联邦优化问题,每轮通信中进行单个批处理的梯度计算
该方法效率高,但需要大量的训练才能得到好的model
FederatedSGD FedSGD: 每一轮选择C分之一的clients,计算这些clients拥有数据的gradient的 loss值
C控制了global batch 的size,C = 1 对应的是full-batch
每个client在本地使用当前的model使用local data做一步的gradient descent
然后server对于result model 做一次加权的average
因此,可以对每个client通过本地迭代update多次来更多的增大计算量。这也就是FederatedAveraging (FedAvg)
三个重要的参数:
C:clients的小数占比
E:每个client在每一轮重的训练的次数,本地的epochs
B:对于clients的updates而言的 local minibatch size
B是无穷时候,表示full local dataset 作为了一个single minibatch
对于一个E= 1 ,B= 无穷就是其中的一个极端
对于一个有nk个local examples的client,每轮local updates的数量是uk = E * nk/ B
Experimental Results
MNIST的digit recognition task
1)一个多层感知机,有2层的hidden layers,使用ReLu激活函数,用200个units,也就是MNIST 2NN
2)一个CNN有两个55的卷积层,第一个是32channels,第二个是64个,每个后面是22的max pooling,一个有512个units的全连接层,和ReLu的激活函数,以及最后的softmax的output layer
考虑了两种MNIST的数据的分布:
1)IID,数据是被shuffled,之后划分到100clients上,每个接收600examples
2)Non-IID,根据digit的label去sort,然后划分为200个碎片,每个shards的size是300,分配给每100个clients 2个shards
Language的modeling
a stacked character-level LSTM language model
在读入一行的character后,预测下一个character
以一系列的字符作为input,将每个character嵌入到一个学习到的8维的空间中
嵌入的字符之后会被一个2层的LSTM model处理,每层有256个nodes
之后第二层的LSTM的输出进入到一个softmax的output layer,每个character只有一个node
Increasing parallelism
C控制了多个clients的并行度
调整C来找到合适的并行
Increasing computation per client
固定C = 0.1
在每一轮给每个client加入更多的computation
要么减小E,增大B
增加更多的local updates可以很大程度降低通信开销
每个client的预期的updates的数量是u = E * nk / B
通过改变E和B来增大u是有效的,只要B足够大,就可以充分利用client硬件的并行性