Exploiting Shared Representations for Personalized Federated Learning 论文笔记+代码解读

论文地址点这里

一. 介绍

联邦学习中由于各个客户端上数据异构问题,导致全局训练模型无法适应每一个客户端的要求。作者通过利用客户端之间的共同代表来解决这个问题。具体来说,将数据异构的联邦学习问题视为并行学习任务,这些任务之间可能存在一些共同的结构,作者的目标是学习和利用这种共同的表示来提高每个客户端的模型质量,基于此提出了FedRep(联邦表示学习)。
FedRep: 联邦表示学习利用跨客户机的存储的所有数据,使用基于梯度的更新来学习全局低维表示。此外,使得每个客户端能够计算一个个性化的、低维的分类器,负责每个客户端的本地数据的唯一标识。

二. 问题定义

传统的联邦学习从n个客户端上优化下面目标:

min ⁡ ( q 1 , . . . , q n ) ∈ Q n 1 n ∑ i = 1 n f i ( q i ) \min_{(q_1,...,q_n)\in\mathcal{Q_n}}\frac{1}{n}\sum_{i=1}^nf_i(q_i) (q1,...,qn)Qnminn1i=1nfi(qi)

其中 f i f_i fi表示第i个客户端上的损失函数, q i q_i qi表示第i个客户端上的模型。但由于客户端上的数据较少,同时客户端数量庞大,客户端无法学习到一个很小损失的模型,因此联邦学习允许客户端之间进行参数交互。传统的方式是想让客户端学习到一个共同的模型,也就是 q 1 = q 2 = . . . = q n q_1=q_2=...=q_n q1=q2=...=qn,但当客户端数据异构明显时,客户端的模型应该更接近于本地的数据。因此我们有必要去学习到一组 { q i } \{q_i\} {qi}使得其满足于自身的数据。
学习一个共同的表示(Learning a Common Representation)。我们考虑一个全局的表示 ϕ : R d → R k \phi:\mathbb{R}^d \to \mathbb{R}^k ϕ:RdRk,将数据映射到一个更低的维度k;客户端的特殊表示头: R k → Y \mathbb{R}^k \to \mathcal{Y} RkY。根据此,第i个客户端上的模型是客户端上的局部参数和全局表示的组合: q i ( x ) = ( h i ∘ ϕ ) ( x ) q_i(x)=(h_i \circ\phi)(x) qi(x)=(hiϕ)(x)。值得注意的是,k远远小于d,也就是说每个客户端必须在本地学习的参数数量很少。我们根据新的内容重新改写我们的全局优化目标:

min ⁡ ϕ ∈ Φ 1 n ∑ i = 1 n min ⁡ h i ∈ H f i ( h i ∘ ϕ ) \min_{\phi \in \Phi}\frac{1}{n}\sum_{i=1}^n\min_{h_i\in\mathcal{H}}f_i({h_i} \circ\phi) ϕΦminn1i=1nhiHminfi(hiϕ)

其中 Φ \Phi Φ为可行的表示类,而 H \mathcal{H} H为可行的头类。客户端使用所有客户的数据协同学习全局模型,同时使用自己的本地信息学习个性化的头部。

三. FedRep算法

算法思想如图所示:
在这里插入图片描述
服务器和客户端共同学习 ϕ \phi ϕ,客户端自己学习自己的参数头 h h h
客户端更新: 在每一轮,被选中的客户端进行训练。这些客户端通过服务端来的 ϕ i \phi_i ϕi进行更新自己的 h i h_i hi,如下:

h i t , s = G R D ( f i ( h i t , s − 1 , ϕ t ) , h i t , s − 1 , α ) h_i^{t,s} = GRD(f_i(h_i^{t,s-1},\phi^t),h_i^{t,s-1},\alpha) hit,s=GRD(fi(hit,s1,ϕt),hit,s1,α)

GRD为一个梯度下降的优化表示,其意思为我们对参数h在f上使用一次梯度下降以 α \alpha α为步长进行更新。训练完 τ h \tau_h τh步更新h后,我们同样 ϕ \phi ϕ进行 τ ϕ \tau_\phi τϕ次更新,如下:

ϕ i t , s = G R D ( f i ( h i t , τ h , ϕ i t , s − 1 ) , ϕ i t , s − 1 , α ) \phi_i^{t,s}=GRD(f_i(h_i^{t,\tau_h},\phi_i^{t,s-1}),\phi_i^{t,s-1},\alpha) ϕit,s=GRD(fi(hit,τh,ϕit,s1),ϕit,s1,α)

服务端更新: 客户端完成更新后返回给服务端 ϕ i t , τ ϕ \phi_i^{t,\tau_\phi} ϕit,τϕ,服务端聚合后后求平均。
具的算法如下图:
在这里插入图片描述

四. 代码详解

作者的代码点这里
相信这一篇文章大家应该理解起来不会有困难,就是分层进行一次处理即可。
首先,最主要关心的就是怎么分层。根据思路,我们需要分为rep层和head层,head为自己的参数。rep则是参与共享,在分层前,我们看一看网络:

class CNNCifar100(nn.Module):
    def __init__(self, args):
        super(CNNCifar100, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.drop = nn.Dropout(0.6)
        self.conv2 = nn.Conv2d(64, 128, 5)
        self.fc1 = nn.Linear(128 * 5 * 5, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, args.num_classes)
        self.cls = args.num_classes

        self.weight_keys = [['fc1.weight', 'fc1.bias'],
                            ['fc2.weight', 'fc2.bias'],
                            ['fc3.weight', 'fc3.bias'],
                            ['conv2.weight', 'conv2.bias'],
                            ['conv1.weight', 'conv1.bias'],
                            ]

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 128 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = self.drop((F.relu(self.fc2(x))))
        x = self.fc3(x)
        return F.log_softmax(x, dim=1)

一个很简单的CNN网络,其中我们把每一层名称储存下来,方便进行分层。

if args.alg == 'fedrep' or args.alg == 'fedper':
    if 'cifar' in  args.dataset:
        w_glob_keys = [net_glob.weight_keys[i] for i in [0,1,3,4]]
    elif 'mnist' in args.dataset:
        w_glob_keys = [net_glob.weight_keys[i] for i in [0,1,2]]
    elif 'sent140' in args.dataset:
        w_glob_keys = [net_keys[i] for i in [0,1,2,3,4,5]]
    else:
        w_glob_keys = net_keys[:-2]

这里就是简要的分层操作,可以看到对于我们处理cifar100的话,rep层取得是 0 1 3 4,对应的也就是除了fc3的最后一层。因此是最后一层为head,其余为rep。
之后开始训练,训练首先对于客户端来说是获取服务端的参数rep再加上自己的参数head,代码为:

if args.alg != 'fedavg' and args.alg != 'prox':
    for k in w_locals[idx].keys():
        if k not in w_glob_keys:
            w_local[k] = w_locals[idx][k]

其中w_glob_keys 就是rep的参数,w_local为所有的参数。
最后就是训练:

for iter in range(local_eps):
    done = False

    # for FedRep, 首先我们训练head固定rep,训练个几轮
    if (iter < head_eps and self.args.alg == 'fedrep') or last:
        for name, param in net.named_parameters():
            if name in w_glob_keys:
                param.requires_grad = False
            else:
                param.requires_grad = True
    
    # 然后训练rep固定head
    elif iter == head_eps and self.args.alg == 'fedrep' and not last:
        for name, param in net.named_parameters():
            if name in w_glob_keys:
                param.requires_grad = True
            else:
                param.requires_grad = False
  • 5
    点赞
  • 15
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值