在传统的机器学习中,数据同在一个中心,不会出现什么非独立同分布的问题;然而在联邦学习中,每个客户端(client)都拥有自己的数据集,大家各不相同,所以数据不独立同分布是常态。因此我们在做实验时,需要模拟真实的场景,对一个数据集进行
Non-IID
的划分。这里参考网上的资料,按
Dirichlet
分布划分Non-IID
数据集。首先处理常见的非结构化数据集,之后再处理结构化数据集。
什么是Non-IID
非独立同分布包括以下三种:
- 不独立但同分布
- 独立但不同分布
- 不独立也不同分布
也就是说,除了独立同分布(independent and identically distributed),其余都是 Non-IID
。
独立没什么好说的,关键在于同分布。
举个栗子,对于标准数据集 cifar-10
,该数据集有 6w 张图片,分为 10 类,每类均为6k张图片。在做传统的图像分类实验中,数据集采用均匀划分的 5w 个作为训练集,1w 个样本作为测试集。我们把训练集和测试集看作两个数据集,他们各自有 10 个类别,每类都占有 1/10 ,也就是他们的标签分布都为 1:1:1……,对于两个数据集的类别数量比相同,这就叫同分布(IID)。
此时我们再想想,为什么多个客户端提供的数据集具有 Non-IID
的性质,因为不同数据集的不同类数量比一般都不同,你提供一个 100:150:200 的数据集,我提供一个 200:301:402的数据集,看起来近似,但分布就是不一样。
如何进行划分
这里先对非结构化的数据集进行划分,之后再处理结构化数据集
非结构化:图像、文本、 视频
结构化:类似于 json 格式的或数据库表那样格式的数据
比较常见的是根据样本的标签分布进行 Non-IID
的划分。
思路如下:
尽量让每个 client 上的样本标签分布不同。我们设有 K 个类别标签, N 个 client,每个类别标签的样本需要按照不同的比例划分在不同的 client 上。我们需要一个类别标签分布的矩阵,其行向量表示类别 k 在不同 client 上的概率分布向量(显然每个行向量的和为1),该随机向量就采样自 Dirichlet
分布。

numpy中的dirichlet函数
def dirichlet(alpha, size=None):
参数:
alpha
: 对应分布函数中的参数向量 α ,长度为 k 。
size
: 为输出形状大小,因为采出的每个样本是一个随机向量,默认最后一维会自动加上 k ,如果给定形状为 (m,n) ,那么 m×n 个维度为 k 的随机向量会从中抽取。默认为 None,即返回一个一个 k 维的随机向量。
返回:
out
: ndarray 采出的样本,大小为 (size,k) 。
这里的 α 越小,得到的差异越大;α 越大,差异越小,也就是越平均。
K 其实就对应着 client 的数量。
函数使用案例
设 α=(10,5,3) (意味着 k=3 ), size=(2,2) ,则采出的样本为 2×2 个维度为 k=3 的随机向量。
import numpy as np
s = np.random.dirichlet((10, 5, 3), size=(2, 2))
print(s)
因为调用了 random
,所以每次打印的结果都是不同的。在实验时应该设置随机数种子,保证相同的数据拆分,即结果的可复现。

划分代码实现
首先再来理一理思路,我们的目的是:将每个类别划分为 N 个子集。
- 拿到数据集对应标签的所有下标
- 分类
- 得到每个类别的所有下标
- 将每个类别用
dirichlet
函数划分为 N 份 - 得到划分后的标签下标
import numpy as np
# 设置随机数种子,保证相同的数据拆分,可获得结果的复现
np.random.seed(42)
def dirichlet_split_noniid(train_labels, alpha, n_clients):
'''
参数为 alpha 的 Dirichlet 分布将数据索引划分为 n_clients 个子集
'''
# 总类别数
n_classes = train_labels.max()+1
# [alpha]*n_clients 如下:
# [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
# 得到 62 * 10 的标签分布矩阵,记录每个 client 占有每个类别的比率
label_distribution = np.random.dirichlet([alpha]*n_clients, n_classes)
# 记录每个类别对应的样本下标
# 返回二维数组
class_idcs = [np.argwhere(train_labels==y).flatten()
for y in range(n_classes)]
# 定义一个空列表作最后的返回值
client_idcs = [[] for _ in range(n_clients)]
# 记录N个client分别对应样本集合的索引
for c, fracs in zip(class_idcs, label_distribution):
# np.split按照比例将类别为k的样本划分为了N个子集
# for i, idcs 为遍历第i个client对应样本集合的索引
for i, idcs in enumerate(np.split(c, (np.cumsum(fracs)[:-1]*len(c)).astype(int))):
client_idcs[i] += [idcs]
client_idcs = [np.concatenate(idcs) for idcs in client_idcs]
return client_idcs
有很多方法对我这 AI 小白并不友好。。记录如下:
- np.argwhere(a > 1):返回数组中大于 1 的下标。
- zip:将对应的元素打包成一个个元组,然后返回由这些元组组成的列表。
- enumerate:将一个可遍历的数据对象(如列表、元组或字符串)组合为一个索引序列,同时列出数据和数据下标,一般用在 for 循环当中。
- np.cumsum:特别抽象的累加

- np.concatenate:连接数组

测试
在 EMNIST
数据集上调用该函数进行测试,并进行可视化呈现。