联邦学习数据集划分Dirichlet划分法——pytorch实现

该文介绍了在联邦学习中,如何使用PyTorch而非numpy实现Dirichlet分布进行数据集划分。通过Dirichlet函数生成标签分布,按比例分配到各个客户端,强调了torch.split与numpy.split的区别。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

联邦学习数据集Dirichlet划分

做联邦学习数据集划分的时候,一般要考虑到数据的特异性,我们一般使用dirichlet分布来产生不同的客户端数据。
网上找的资料大部分都是numpy实现的dirichlet划分,但是因为强迫症 不想额外引入numpy,这里将介绍一下torch如何实现dirichlet划分的方法:
完整代码如下:
参数:

  1. train_labels: 数据集的标签列表
  2. dirichlet分布参数
  3. n_clients:有几个客户端需要分配

小小的解释一下:
首先我们使用Dirichlet函数返回了一个标签分布的矩阵tensor,这个tensor的维度是特征数X客户端数,每一行就是一个标签在不同客户端上的分布,总和为1。然后我们获得每一个标签的下标class_idcs。获得了这两个矩阵之后,我们只需要循环遍历每一个标签,就是:
for c, fracs in zip(class_idcs, label_distribution):
每次取出一个标签的下标位置,以及在每一个客户端的分配比例。通过分配比例拆分此标签,获得每一个客户端拥有的此标签的下标。

import torch
from torch.distributions.dirichlet import Dirichlet


def dirichlet_split_noniid(train_labels, alpha, n_clients):
    n_classes = train_labels.max() + 1
    label_distribution = Dirichlet(torch.full((n_clients,), alpha)).sample((n_classes,))
    # 1. Get the index of each label
    class_idcs = [torch.nonzero(train_labels == y).flatten()
                  for y in range(n_classes)]
    # 2. According to the distribution, the label is assigned to each client
    client_idcs = [[] for _ in range(n_clients)]

    for c, fracs in zip(class_idcs, label_distribution):
        total_size = len(c)
        splits = (fracs * total_size).int()
        splits[-1] = total_size - splits[:-1].sum()
        idcs = torch.split(c, splits.tolist())
        for i, idx in enumerate(idcs):
            client_idcs[i] += [idcs[i]]

    client_idcs = [torch.cat(idcs) for idcs in client_idcs]
    return client_idcs

唯一的坑点我觉得的就是,torch的split和numpy的split居然不一样。
numpy的split是按照累积和来分解,例如要将1~100分解成1 ~ 10, 10 ~ 50,numpy输入的数组是[10,50],而torch是按照实际大小,输入数组是[10, 40, 50]。不过个人感觉torch这个就直观很多,每一个拆分有多少数据。

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

volcanical

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值