模拟联邦学习

参考代码,https://github.com/shaoxiongji/federated-learning

将普通的训练代码,更改为模拟联邦学习的代码,选用最简单的方式,使用for循环模拟客户端,采用Fedavg进行聚合,进而实现联邦学习的模拟。

  • 核心代码
    #准备训练
    global_model.to(args.device)
    global_model.train()

    # 复制当前全局模型global_model的权重
    global_weights = global_model.state_dict()

    # training
    loss_train = []

    #联邦学习算法的主要训练循环
    for epoch in range(args.epochs):  # epochs全局轮次,全局训练循环
        local_weights, local_losses = [], []  #每个客户端的本地权重  本地损失

        print(f'\n | Global Training Round : {epoch + 1} |\n')
        global_model.train()

        # 计算每轮训练将选择的客户端的数量。 num_users全部的客户端,frac选择的比例
        m = max(int(args.frac * args.num_users), 1)
        #随机选择客户端
        idxs_users = np.random.choice(range(args.num_users), m, replace=False)
        #利用for循环实现每个客户端的本地训练
        #遍历被选中的客户端
        for idx in idxs_users:
            loader_class = build_dist_loaders if not args.data_path else build_dataloaders
            train_dataset, valid_dataset, train_loader, valid_loader, train_sampler, valid_sampler = loader_class(args,
                                                                                                                  tokenizer,
                                                                                                                  logger,
                                                                                                                  idx)
        #得到客户端本地训练后的权重和损失
            w, loss = train(args=args, model=copy.deepcopy(global_model), idx=idx, train_dataset=train_dataset,
                            valid_dataset=valid_dataset, tokenizer=tokenizer, logger=logger)


            local_weights.append(copy.deepcopy(w))
            local_losses.append(copy.deepcopy(loss))

        #更新全局权重,使用FedAvg
        global_weights = FedAvg(local_weights)
        #将更新后的全局权重global_weights 加载到全局模型global_model中,以便在下一轮迭代中使用
        global_model.load_state_dict(global_weights)

        print("local_losses:", local_losses)
        #计算本轮训练的平均损失
        loss_avg = sum(local_losses) / len(local_losses)
        print('Round {:3d}, Average loss {:.3f}'.format(epoch, loss_avg))

        loss_train.append(loss_avg)

  • Fedavg代码
def FedAvg(w):
    w_avg = copy.deepcopy(w[0])
    for k in w_avg.keys():
        for i in range(1, len(w)):
            w_avg[k] += w[i][k]
        w_avg[k] = torch.div(w_avg[k], len(w))
    return w_avg

以上就可根据自己的训练代码,进行模拟联邦学习的修改。在for循环中,写入自己的train函数即可。但要对数据加载进行处理,对数据进行分割,将数据分成n份(模拟n个客户端),可以按照索引对数据进行区分,在模拟不同客户端训练时,传入该客户端对应的数据进行训练。

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值