一. 介绍
联邦学习中由于各个客户端上数据异构问题,导致全局训练模型无法适应每一个客户端的要求。作者通过利用客户端之间的共同代表来解决这个问题。具体来说,将数据异构的联邦学习问题视为并行学习任务,这些任务之间可能存在一些共同的结构,作者的目标是学习和利用这种共同的表示来提高每个客户端的模型质量,基于此提出了FedRep(联邦表示学习)。
FedRep: 联邦表示学习利用跨客户机的存储的所有数据,使用基于梯度的更新来学习全局低维表示。此外,使得每个客户端能够计算一个个性化的、低维的分类器,负责每个客户端的本地数据的唯一标识。
二. 问题定义
传统的联邦学习从n个客户端上优化下面目标:
min ( q 1 , . . . , q n ) ∈ Q n 1 n ∑ i = 1 n f i ( q i ) \min_{(q_1,...,q_n)\in\mathcal{Q_n}}\frac{1}{n}\sum_{i=1}^nf_i(q_i) (q1,...,qn)∈Qnminn1i=1∑nfi(qi)
其中
f
i
f_i
fi表示第i个客户端上的损失函数,
q
i
q_i
qi表示第i个客户端上的模型。但由于客户端上的数据较少,同时客户端数量庞大,客户端无法学习到一个很小损失的模型,因此联邦学习允许客户端之间进行参数交互。传统的方式是想让客户端学习到一个共同的模型,也就是
q
1
=
q
2
=
.
.
.
=
q
n
q_1=q_2=...=q_n
q1=q2=...=qn,但当客户端数据异构明显时,客户端的模型应该更接近于本地的数据。因此我们有必要去学习到一组
{
q
i
}
\{q_i\}
{qi}使得其满足于自身的数据。
学习一个共同的表示(Learning a Common Representation)。我们考虑一个全局的表示
ϕ
:
R
d
→
R
k
\phi:\mathbb{R}^d \to \mathbb{R}^k
ϕ:Rd→Rk,将数据映射到一个更低的维度k;客户端的特殊表示头:
R
k
→
Y
\mathbb{R}^k \to \mathcal{Y}
Rk→Y。根据此,第i个客户端上的模型是客户端上的局部参数和全局表示的组合:
q
i
(
x
)
=
(
h
i
∘
ϕ
)
(
x
)
q_i(x)=(h_i \circ\phi)(x)
qi(x)=(hi∘ϕ)(x)。值得注意的是,k远远小于d,也就是说每个客户端必须在本地学习的参数数量很少。我们根据新的内容重新改写我们的全局优化目标:
min ϕ ∈ Φ 1 n ∑ i = 1 n min h i ∈ H f i ( h i ∘ ϕ ) \min_{\phi \in \Phi}\frac{1}{n}\sum_{i=1}^n\min_{h_i\in\mathcal{H}}f_i({h_i} \circ\phi) ϕ∈Φminn1i=1∑nhi∈Hminfi(hi∘ϕ)
其中 Φ \Phi Φ为可行的表示类,而 H \mathcal{H} H为可行的头类。客户端使用所有客户的数据协同学习全局模型,同时使用自己的本地信息学习个性化的头部。
三. FedRep算法
算法思想如图所示:
服务器和客户端共同学习
ϕ
\phi
ϕ,客户端自己学习自己的参数头
h
h
h。
客户端更新: 在每一轮,被选中的客户端进行训练。这些客户端通过服务端来的
ϕ
i
\phi_i
ϕi进行更新自己的
h
i
h_i
hi,如下:
h i t , s = G R D ( f i ( h i t , s − 1 , ϕ t ) , h i t , s − 1 , α ) h_i^{t,s} = GRD(f_i(h_i^{t,s-1},\phi^t),h_i^{t,s-1},\alpha) hit,s=GRD(fi(hit,s−1,ϕt),hit,s−1,α)
GRD为一个梯度下降的优化表示,其意思为我们对参数h在f上使用一次梯度下降以 α \alpha α为步长进行更新。训练完 τ h \tau_h τh步更新h后,我们同样 ϕ \phi ϕ进行 τ ϕ \tau_\phi τϕ次更新,如下:
ϕ i t , s = G R D ( f i ( h i t , τ h , ϕ i t , s − 1 ) , ϕ i t , s − 1 , α ) \phi_i^{t,s}=GRD(f_i(h_i^{t,\tau_h},\phi_i^{t,s-1}),\phi_i^{t,s-1},\alpha) ϕit,s=GRD(fi(hit,τh,ϕit,s−1),ϕit,s−1,α)
服务端更新: 客户端完成更新后返回给服务端
ϕ
i
t
,
τ
ϕ
\phi_i^{t,\tau_\phi}
ϕit,τϕ,服务端聚合后后求平均。
具的算法如下图:
四. 代码详解
作者的代码点这里
相信这一篇文章大家应该理解起来不会有困难,就是分层进行一次处理即可。
首先,最主要关心的就是怎么分层。根据思路,我们需要分为rep层和head层,head为自己的参数。rep则是参与共享,在分层前,我们看一看网络:
class CNNCifar100(nn.Module):
def __init__(self, args):
super(CNNCifar100, self).__init__()
self.conv1 = nn.Conv2d(3, 64, 5)
self.pool = nn.MaxPool2d(2, 2)
self.drop = nn.Dropout(0.6)
self.conv2 = nn.Conv2d(64, 128, 5)
self.fc1 = nn.Linear(128 * 5 * 5, 256)
self.fc2 = nn.Linear(256, 128)
self.fc3 = nn.Linear(128, args.num_classes)
self.cls = args.num_classes
self.weight_keys = [['fc1.weight', 'fc1.bias'],
['fc2.weight', 'fc2.bias'],
['fc3.weight', 'fc3.bias'],
['conv2.weight', 'conv2.bias'],
['conv1.weight', 'conv1.bias'],
]
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 128 * 5 * 5)
x = F.relu(self.fc1(x))
x = self.drop((F.relu(self.fc2(x))))
x = self.fc3(x)
return F.log_softmax(x, dim=1)
一个很简单的CNN网络,其中我们把每一层名称储存下来,方便进行分层。
if args.alg == 'fedrep' or args.alg == 'fedper':
if 'cifar' in args.dataset:
w_glob_keys = [net_glob.weight_keys[i] for i in [0,1,3,4]]
elif 'mnist' in args.dataset:
w_glob_keys = [net_glob.weight_keys[i] for i in [0,1,2]]
elif 'sent140' in args.dataset:
w_glob_keys = [net_keys[i] for i in [0,1,2,3,4,5]]
else:
w_glob_keys = net_keys[:-2]
这里就是简要的分层操作,可以看到对于我们处理cifar100的话,rep层取得是 0 1 3 4,对应的也就是除了fc3的最后一层。因此是最后一层为head,其余为rep。
之后开始训练,训练首先对于客户端来说是获取服务端的参数rep再加上自己的参数head,代码为:
if args.alg != 'fedavg' and args.alg != 'prox':
for k in w_locals[idx].keys():
if k not in w_glob_keys:
w_local[k] = w_locals[idx][k]
其中w_glob_keys 就是rep的参数,w_local为所有的参数。
最后就是训练:
for iter in range(local_eps):
done = False
# for FedRep, 首先我们训练head固定rep,训练个几轮
if (iter < head_eps and self.args.alg == 'fedrep') or last:
for name, param in net.named_parameters():
if name in w_glob_keys:
param.requires_grad = False
else:
param.requires_grad = True
# 然后训练rep固定head
elif iter == head_eps and self.args.alg == 'fedrep' and not last:
for name, param in net.named_parameters():
if name in w_glob_keys:
param.requires_grad = True
else:
param.requires_grad = False