Resource-aware Federated Learning using Knowledge Extraction and Multi-model Fusion论文阅读+代码解析

这段时间都在更新持续学习,今天在网上找到一篇关于联邦学习中不同模型的融合问题的论文,在这里分享给大家,论文地址点这里

一. 介绍

联邦学习(FL)已成为一种新的机器学习范式,用于分布式客户端参与集中式模型的协作训练。FL将模型异步培训引入边缘,设备(如手机和物联网设备)提取关于私有敏感培训数据的知识,然后将学习到的模型上传到云端进行聚合。FL在本地存储用户数据,并限制从云服务器直接访问;因此,这种模式不仅增强了隐私保护,而且引入了一些固有的优点,包括模型准确性、成本效率和多样性。随着当今机器学习模型对数据的巨大需求和人工智能的社会考量(安全性和隐私性)联合学习在权衡这一权衡方面具有巨大的潜力和作用。
常规的联邦学习对所有的模型权重进行加权平均,但是受到边缘设备数据异构的情况,可能会产生不公平、无效的全局模型,并且无法有效地部署。为了克服先前FL方法的上述局限性,本文提出了使用知识提取和多模型融合(FedKEMF)执行资源感知联邦学习的想法,如下图所示:
在这里插入图片描述

二. 方法

2.1 使用深度互学习网络进行知识提取

深度互学习的关键是训练多个神经网络同时最小化网络输出的Kullback Leibler。换句话说,KL散度评估了两种分布的相似性。通过最小化神经网络之间的KL发散,使得它们可以相互学习知识。因此,为了从本地模型中提取知识,我们在本地引入了一个知识网络(与本地模型相比是一个较小的网络),并使用深度互学习同时优化知识网络和本地模型。
为了直观地解释知识提取过程,我们在图像分类任务中解释了它。形式上,在边缘客户端中,我们有一个本地模型 θ \theta θ和一个知识网络(微型网络) θ g \theta_g θg。首先,更新局部模型 θ \theta θ:对于任何一批输入数据 x x x,我们使用交叉熵损失去计算预测值和真实标签的损失:
L c = − ∑ i = 1 N y T log ⁡ ( σ ( θ ( x i ) ) ) L_c=-\sum_{i=1}^N y^T \log \left(\sigma\left(\theta\left(x_i\right)\right)\right) Lc=i=1NyTlog(σ(θ(xi)))
其中 N N N表示为mini-batch的大小, σ \sigma σ为softmax函数。
然后我们计算 θ ( x ) \theta(x) θ(x) θ g ( x ) \theta_g(x) θg(x)之间的KL距离:
D K L ( θ g ∥ θ ) = ∑ i = 1 N σ ( θ g ( x ) T ) log ⁡ ( σ ( θ g ( x ) ) σ ( θ ( x ) ) ) D_{K L}\left(\theta_g \| \theta\right)=\sum_{i=1}^N \sigma\left(\theta_g(x)^T\right) \log \left(\frac{\sigma\left(\theta_g(x)\right)}{\sigma(\theta(x))}\right) DKL(θgθ)=i=1Nσ(θg(x)T)log(σ(θ(x))σ(θg(x)))
那么更新 θ \theta θ的总体损失如下:
L θ = L c + D K L ( θ g ∥ θ ) L_\theta=L_c+D_{K L}\left(\theta_g \| \theta\right) Lθ=Lc+DKL(θgθ)
下面是本地更新的算法:
在这里插入图片描述

2.2 多模型知识融合

在FedKEMF中,我们提供了两种模型融合方法,用于边缘知识的服务器融合。第一个类似于传统的FL,因为我们合计了重量。其次,受(Lin et al.2020)启发,我们将所有收到的客户模型进行集成,并将集成知识提取到全球知识网络中。在本节和实验中,我们主要关注整合客户的知识。然而,FedKEMF也可以使用传统的融合方法。
定义集成模型为 Θ = { θ g k } k ∈ S \Theta=\{\theta^k_g\}_{k\in S} Θ={θgk}kS,其中 θ g k \theta^k_g θgk表示为第 k k k个客户端的知识网络, S S S为和当前服务器通信的客户端集合。然后,通过使用服务器中的未标记数据、生成数据或公共数据,我们将集合 Θ \Theta Θ的知识蒸馏到一个全局知识网络 θ g \theta_g θg,如下:
L d = D K L ( Θ , θ g ) L_d=D_{K L}\left(\Theta, \theta_g\right) Ld=DKL(Θ,θg)
具体算法描述如下:
在这里插入图片描述

2.3 知识集成

在FedKEMF中,我们研究了三种集成策略,即最大logits、平均logits和多数投票。由于在实际应用中,采用最大对数作为集成策略效果最好。对于给定的输入实例 x x x,集成模型的计算公式如下:
Θ ( x ) =  Ensemble  Max ⁡ ( { θ g k ( x ) } k ∈ S ) \Theta(x)=\text { Ensemble }_{\operatorname{Max}}\left(\left\{\theta_g^k(x)\right\}_{k \in S}\right) Θ(x)= Ensemble Max({θgk(x)}kS)

三. 代码解析

代码链接点这里
在这篇代码的工作中,作者在resnet20、32、44三种模型中随机分配给不同的客户端:

def init_fl(n_parties,model_name, args):
    if args.env == 'multi-model':
        nets = {net_i: None for net_i in range(n_parties)}
        net_name = ['resnet20','resnet32','resnet44']
        for net_i in range(n_parties):
            import random
            net_type = net_name[random.randint(0, 2)]
            net = resnet.__dict__[net_type]()
            if args.ckpt_path is not None:
                # path = os.path.join(data_root, "pretrained_models",'resnet56-4bfd9763.th')
                checkpoint = torch.load(args.ckpt_path, map_location=args.device)
                sd = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint
                net.load_state_dict(sd)

            net = torch.nn.DataParallel(net)
            nets[net_i] = net

        return nets, None, None
    else:
        return init_nets(n_parties,model_name, args)

之后每个客户端的知识模型使用的是一样的,这里不再赘述。
在训练中,首先是各个客户端的训练:

def local_update(nets:dict, g_k:nn.Module, selected, args, net_dataidx_map, test_dl_global,logger,lr = 0.01, device="cpu"):
    avg_acc = 0.0
    avg_kacc = 0.0
    k_nets = []
    for net_id, net in nets.items():
        if net_id not in selected:
            continue
        dataidxs = net_dataidx_map[net_id]

        noise_level = args.noise
        if net_id == args.n_parties - 1:
            noise_level = 0

        if args.noise_type == 'space':
            train_dl_local, test_dl_local, _, _ = get_dataloader(args.dataset, args.datadir, args.batch_size, 32, dataidxs, noise_level, net_id, args.n_parties-1)
        else:
            noise_level = args.noise / (args.n_parties - 1) * net_id
            train_dl_local, test_dl_local, _, _ = get_dataloader(args.dataset, args.datadir, args.batch_size, 32, dataidxs, noise_level)
        n_epoch = args.epochs



        logger.info("Training network %s. n_training: %d" % (str(net_id), len(dataidxs)))
        # move the model to cuda device:
        net.to(device)

        kd_agent = Distiller(lr=lr,epochs=n_epoch,device=device)

        nets_cohort = []
        optimizers = []

        nets_cohort.append(net)
        net_optimizer = optim.SGD(net.parameters(), lr)
        optimizers.append(net_optimizer)

        l_g_k = copy.deepcopy(g_k)
        k_nets.append(l_g_k)
        nets_cohort.append(l_g_k)
        gk_optimizer = optim.SGD(l_g_k.parameters(), lr)
        optimizers.append(gk_optimizer)



        #### if use the local test_dl (also need modify get_dataloader function):
        '''
        nets_cohort, lr_ = kd_agent.mutual_kd(nets_cohort,train_dl_local, test_dl_local,optimizers,s_save_path=None)

        acc =compute_accuracy(nets_cohort[0],test_dl_local,device=device)
        print("net %d final test acc %f" % (net_id, acc))
        logger.info("net %d final test acc %f" % (net_id, acc))
        acc_k = compute_accuracy(nets_cohort[1],test_dl_local,device=device)
        print("net %d knowledge network test acc %f" % (net_id, acc_k))
        logger.info("net %d knowledge network test acc %f" % (net_id, acc_k))
        '''


        #### if use the global test_dl:

        nets_cohort, lr_ = kd_agent.mutual_kd(nets_cohort, train_dl_local, test_dl_global, optimizers, s_save_path=None)

        acc = compute_accuracy(nets_cohort[0], test_dl_global, device=device)
        print("net %d final test acc %f" % (net_id, acc))
        logger.info("net %d final test acc %f" % (net_id, acc))
        acc_k = compute_accuracy(nets_cohort[1], test_dl_global, device=device)
        print("net %d knowledge network test acc %f" % (net_id, acc_k))
        logger.info("net %d knowledge network test acc %f" % (net_id, acc_k))


        avg_acc += acc
        avg_kacc += acc_k

    avg_acc /= len(selected)
    avg_kacc /= len(selected)
    logger.info("avg test acc after local update %f" % avg_acc)
    logger.info("avg knowledge test acc after local update %f" % avg_kacc)

    nets_list = list(nets.values())
    return nets_list,k_nets,lr_

    # raise NotImplementedError

可以发现作者将本地模型和知识模型一起放进了nets_cohort列表中,传给Deep Mutual 模块进行学习,这里如下:

def train_students(self,epochs=20,plot_losses=True,save_model=True,save_model_path="./models/student.pth",):
    for student in self.student_cohort:
        student.train()
    loss_arr = []
    length_of_dataset = len(self.train_loader.dataset)
    num_students = len(self.student_cohort                      )
    print("\nTraining students...")
    for ep in range(epochs):
        epoch_loss = 0.0
        correct = 0
        for optimizer in self.student_optimizers:
            lr = adjust_learning_rate(optimizer, ep, self.lr, epochs, lr_type='cos')
        for (data, label) in self.train_loader:
            data = data.to(self.device)
            label = label.to(self.device)
            for optim in self.student_optimizers:
                optim.zero_grad()
            avg_student_loss = 0
            for i in range(num_students):
                student_loss = 0
                for j in range(num_students):
                    if i == j:
                        continue
                    student_loss += self.loss_fn(
                        self.student_cohort[i](data), self.student_cohort[j](data)
                    )
                student_loss /= num_students - 1
                student_loss += F.cross_entropy(self.student_cohort[i](data), label)
                student_loss.backward()
                self.student_optimizers[i].step()
                avg_student_loss += student_loss
            avg_student_loss /= num_students

            predictions = []
            correct_preds = []
            for i, student in enumerate(self.student_cohort):
                predictions.append(student(data).argmax(dim=1, keepdim=True))
                correct_preds.append(
                    predictions[i].eq(label.view_as(predictions[i])).sum().item()
                )

            correct += max(correct_preds)

            epoch_loss += avg_student_loss

        epoch_acc = correct / length_of_dataset

        if self.log:
            self.writer.add_scalar("Training loss/Student", epoch_loss, epochs)
            self.writer.add_scalar("Training accuracy/Student", epoch_acc, epochs)

        loss_arr.append(epoch_loss)
        print(f"Epoch: {ep + 1}, Loss: {epoch_loss}, Training Accuracy: {epoch_acc}")

    return lr

这里的student_cohort就是刚刚传入的nets_cohort,显然student_number=2。这里两个网络的KL损失使用的是nn.MSELoss(),存在了self.loss_fn中,在上述代码也有体现。

服务器的代码如下:

def cloud_update(l_ks,g_k,train_loader,test_loader,lr, n_epoch,device):
    ensemble = AvgEnsemble(l_ks)
    # teh_optimizer = optim.SGD(ensemble.parameters(), lr)
    stu_optimizer = optim.SGD(g_k.parameters(), lr)
    kd_agent = Distiller(lr=lr,epochs=n_epoch,device=device)
    _, _,_, lr = kd_agent.pure_kd(ensemble,g_k,None,stu_optimizer,train_loader,test_loader,device='cuda')
    return lr

首先是Ensemble的提取:

class AvgEnsemble(nn.Module):
    def __init__(self, net_list):
        super(AvgEnsemble, self).__init__()
        self.estimators = nn.ModuleList(net_list)

    def forward(self, x):
        outputs = [
            F.softmax(estimator(x), dim=1) for estimator in self.estimators
        ]
        proba = average(outputs)

        return proba

可以发现就是把服务器样本依次输入每个客户端模型,然后取输出的平均。最后和客户端更新类似,如下:

def pure_kd(self, teacher, student, teh_optimizer, stu_optimizer, train_loader, test_loader,
            device='cuda',s_save_path=None):
    '''
        Pure distillation. Distillate the teacher's knowledge to student.
        Only student will be updated.
        :param teacher_model (torch.nn.Module): Teacher model
        :param student_model (torch.nn.Module): Student model
        :param train_loader (torch.utils.data.DataLoader): Dataloader for training
        :param val_loader (torch.utils.data.DataLoader): Dataloader for validation/testing
        :param optimizer_teacher (torch.optim.*): Optimizer used for training teacher
        :param optimizer_student (torch.optim.*): Optimizer used for training student
    '''


    if self.kd_type == 'VanillaKD':
        kd = VanillaKD(teacher, student, train_loader, test_loader,
                       teh_optimizer, stu_optimizer,lr=self.lr ,device=device)
    lr = kd.train_student(epochs=self.epochs, plot_losses=False, save_model=False)

    stu_acc = kd.evaluate(teacher=False)  # eval student
    return kd.teacher_model, kd.student_model,stu_acc, lr
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值