联邦学习数据集Dirichlet划分
做联邦学习数据集划分的时候,一般要考虑到数据的特异性,我们一般使用dirichlet分布来产生不同的客户端数据。
网上找的资料大部分都是numpy实现的dirichlet划分,但是因为强迫症 不想额外引入numpy,这里将介绍一下torch如何实现dirichlet划分的方法:
完整代码如下:
参数:
- train_labels: 数据集的标签列表
- dirichlet分布参数
- n_clients:有几个客户端需要分配
小小的解释一下:
首先我们使用Dirichlet函数返回了一个标签分布的矩阵tensor,这个tensor的维度是特征数X客户端数,每一行就是一个标签在不同客户端上的分布,总和为1。然后我们获得每一个标签的下标class_idcs。获得了这两个矩阵之后,我们只需要循环遍历每一个标签,就是:
for c, fracs in zip(class_idcs, label_distribution):
每次取出一个标签的下标位置,以及在每一个客户端的分配比例。通过分配比例拆分此标签,获得每一个客户端拥有的此标签的下标。
import torch
from torch.distributions.dirichlet import Dirichlet
def dirichlet_split_noniid(train_labels, alpha, n_clients):
n_classes = train_labels.max() + 1
label_distribution = Dirichlet(torch.full((n_clients,), alpha)).sample((n_classes,))
# 1. Get the index of each label
class_idcs = [torch.nonzero(train_labels == y).flatten()
for y in range(n_classes)]
# 2. According to the distribution, the label is assigned to each client
client_idcs = [[] for _ in range(n_clients)]
for c, fracs in zip(class_idcs, label_distribution):
total_size = len(c)
splits = (fracs * total_size).int()
splits[-1] = total_size - splits[:-1].sum()
idcs = torch.split(c, splits.tolist())
for i, idx in enumerate(idcs):
client_idcs[i] += [idcs[i]]
client_idcs = [torch.cat(idcs) for idcs in client_idcs]
return client_idcs
唯一的坑点我觉得的就是,torch的split和numpy的split居然不一样。
numpy的split是按照累积和来分解,例如要将1~100分解成1 ~ 10, 10 ~ 50,numpy输入的数组是[10,50],而torch是按照实际大小,输入数组是[10, 40, 50]。不过个人感觉torch这个就直观很多,每一个拆分有多少数据。