基于PyTorch实现联邦学习的基本算法FedAvg

I. 前言

在之前的一篇博客联邦学习基本算法FedAvg的代码实现中利用numpy手搭神经网络实现了FedAvg,手搭的神经网络效果已经很好了,不过这还是属于自己造轮子,建议优先使用PyTorch来实现。

II. 数据介绍

联邦学习中存在多个客户端,每个客户端都有自己的数据集,这个数据集他们是不愿意共享的。

本文选用的数据集为中国北方某城市十个区/县从2016年到2019年三年的真实用电负荷数据,采集时间间隔为1小时,即每一天都有24个负荷值。

我们假设这10个地区的电力部门不愿意共享自己的数据,但是他们又想得到一个由所有数据统一训练得到的全局模型。

除了电力负荷数据以外,还有一个备选数据集:风功率数据集。两个数据集通过参数type指定:type == 'load’表示负荷数据,'wind’表示风功率数据。

特征构造

用某一时刻前24个时刻的负荷值以及该时刻的相关气象数据(如温度、湿度、压强等)来预测该时刻的负荷值。

对于风功率数据,同样使用某一时刻前24个时刻的风功率值以及该时刻的相关气象数据来预测该时刻的风功率值。

各个地区应该就如何制定特征集达成一致意见,本文使用的各个地区上的数据的特征是一致的,可以直接使用。

III. 联邦学习

1. 整体框架

原始论文中提出的FedAvg的框架为:
在这里插入图片描述
客户端模型采用PyTorch搭建:

class ANN(nn.Module):
    def __init__(self, args, name):
        super(ANN, self).__init__()
        self.name = name
        self.len = 0
        self.loss = 0
        self.sigmoid = nn.Sigmoid()
        self.fc1 = nn.Linear(args.input_dim, 16)
        self.fc2 = nn.Linear(16, 32)
        self.fc3 = nn.Linear(32, 16)
        self.fc4 = nn.Linear(16, 1)

    def forward(self, data):
        x = self.fc1(data)
        x = self.sigmoid(x)
        x = self.fc2(x)
        x = self.sigmoid(x)
        x = self.fc3(x)
        x = self.sigmoid(x)
        x = self.fc4(x)
        x = self.sigmoid(x)

        return x

2. 服务器端

服务器端执行以下步骤:

  1. 初始化参数
  2. 对第 t t t轮训练来说:首先计算出 m = m a x ( C ⋅ K , 1 ) m=max(C \cdot K, 1) m=max(CK,1),然后随机选择 m m m个客户端,对这 m m m个客户端做如下操作(所有客户端并行执行):更新本地的 w t k w_t^{k} wtk得到 w t + 1 k w_{t+1}^{k} wt+1k m m m个客户端更新结束后,将 w t + 1 k w_{t+1}^{k} wt+1k传到服务器,服务器整合所有 w t + 1 k w_{t+1}^{k} wt+1k得到最新的全局参数 w t + 1 w_{t+1} wt+1

简单来说,每一轮通信时都只是选择部分客户端,这些客户端利用本地的数据进行参数更新,然后将更新后的参数传给服务器,服务器汇总客户端更新后的参数形成最新的全局参数。下一轮通信时,服务器端将最新的参数分发给被选中的客户端,进行下一轮更新。

3. 客户端

客户端没什么可说的,就是利用本地数据对神经网络模型的参数进行更新。

IV. 代码实现

1. 初始化

class FedAvg:
    def __init__(self, args):
        self.args = args
        self.clients = args.clients
        self.nn = ANN(args, name='server').to(args.device)
        self.nns = []
        for i in range(args.K):
            temp = copy.deepcopy(self.nn)
            temp.name = self.clients[i]
            self.nns.append(temp)

参数保存在args中:

  1. K,客户端数量,本文为10个,也就是10个地区。
  2. C:选择率,每一轮通信时都只是选择C * K个客户端。
  3. E:客户端更新本地模型的参数时,在本地数据集上训练E轮。
  4. B:客户端更新本地模型的参数时,本地数据集batch大小为B
  5. r:服务器端和客户端一共进行r轮通信。
  6. clients:客户端集合。
  7. type:指定数据类型,负荷预测or风功率预测。
  8. lr:学习率。
  9. input_dim:数据输入维度。
  10. nn:全局模型。
  11. nns: 客户端模型集合。

2. 服务器端

服务器端代码如下:

def server(self):
     for t in range(self.r):
          print('第', t + 1, '轮通信:')
          m = np.max([int(self.C * self.K), 1])
          # sampling
          index = random.sample(range(0, self.K), m)
          # dispatch
          self.dispatch(index)
          # local updating
          self.client_update(index)
          # aggregation
          self.aggregation(index)
     # return global model
     return self.nn

其中client_update(index):

def client_update(self, index):  # update nn
     for k in index:
          self.nns[k] = train(self.nns[k])

aggregation(index):

def aggregation(self, index):
     s = 0
     for j in index:
          # normal
          s += self.nns[j].len
          
     params = {}
     with torch.no_grad():
          for k, v in self.nns[0].named_parameters():
               params[k] = copy.deepcopy(v)
               params[k].zero_()
     for j in index:
          with torch.no_grad():
               for k, v in self.nns[j].named_parameters():
                    params[k] += v * (self.nns[j].len / s)
     with torch.no_grad():
          for k, v in self.nn.named_parameters():
               v.copy_(params[k])

dispatch(index):

def dispatch(self, index):
     params = {}
     with torch.no_grad():
          for k, v in self.nn.named_parameters():
               params[k] = copy.deepcopy(v)
     for j in index:
          with torch.no_grad():
               for k, v in self.nns[j].named_parameters():
                    v.copy_(params[k])

下面对重要代码进行分析:

  • 客户端的选择
m = np.max([int(self.C * self.K), 1])
index = random.sample(range(0, self.K), m)

index中存储中m个0~10间的整数,表示被选中客户端的序号。

  • 客户端的更新
for k in index:
    self.client_update(self.nns[k])
  • 服务器端汇总客户端模型的参数

关于模型汇总方式,可以参考一下我的另一篇文章:对FedAvg中模型聚合过程的理解

当然,这只是一种很简单的汇总方式,还有一些其他类型的汇总方式。论文Electricity Consumer Characteristics Identification: A Federated Learning Approach中总结了三种汇总方式:

  1. normal:原始论文中的方式,即根据样本数量来决定客户端参数在最终组合时所占比例。
  2. LA:根据客户端模型的损失占所有客户端损失和的比重来决定最终组合时参数所占比例。
  3. LS:根据损失与样本数量的乘积所占的比重来决定。
  • 将更新后的参数分发给被选中的客户端
def dispatch(self, index):
     params = {}
     with torch.no_grad():
          for k, v in self.nn.named_parameters():
               params[k] = copy.deepcopy(v)
     for j in index:
          with torch.no_grad():
               for k, v in self.nns[j].named_parameters():
                    v.copy_(params[k])

3. 客户端

客户端只需要利用本地数据来进行更新就行了:

def client_update(self, index):  # update nn
     for k in index:
          self.nns[k] = train(self.nns[k])

4. 测试

def global_test(self):
     model = self.nn
     model.eval()
     c = clients if self.type == 'load' else clients_wind
     for client in c:
          model.name = client
          test(model)

V. 实验及结果

本次实验的参数选择为:

KCEBr
100.550505
def main():
    args = args_parser()
    fed = FedAvg(args)
    fed.server()
    fed.global_test()

各个客户端单独训练(训练50轮,batch大小为50)后在本地的测试集上的表现为:

客户端编号12345678910
MAPE / %5.334.113.034.203.022.702.942.992.304.10

可以看到,由于各个客户端的数据都十分充足,所以每个客户端自己训练的本地模型的预测精度已经很高了。

服务器与客户端通信5轮后,服务器上的全局模型在10个客户端测试集上的表现如下所示:

客户端编号12345678910
MAPE / %6.844.543.565.113.754.474.303.903.154.58

可以看到,经过联邦学习框架得到全局模型在各个客户端上表现同样很好,这是因为十个地区上的数据分布类似。

给出numpy和PyTorch的对比:

客户端编号12345678910
本地5.334.113.034.203.022.702.942.992.304.10
numpy6.584.193.175.133.584.694.713.752.944.77
PyTorch6.844.543.565.113.754.474.303.903.154.58

同样本地模型的效果是最好的,PyTorch搭建的网络和numpy搭建的网络效果差不多,但推荐使用PyTorch,不要造轮子。

VI. 源码及数据

后面将陆续公开~

  • 24
    点赞
  • 107
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 14
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 14
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Cyril_KI

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值