Non-IID(非独立同分布):与IID相反,Non-IID数据可能存在某些依赖关系或可能来自不同的分布。在此代码片段中,可以选择“相等”或“不等”分割,分别对应于所有用户之间平均分配的数据(但不是同分布的),或某些用户可能获得更多或更少数据的情况。
def cifar_noniid(dataset, num_users):
"""
Sample non-I.I.D client data from CIFAR10 dataset
:param dataset:
:param num_users:
:return:
"""
num_shards, num_imgs = 200, 250
idx_shard = [i for i in range(num_shards)]
dict_users = {i: np.array([]) for i in range(num_users)}
idxs = np.arange(num_shards*num_imgs)
# labels = dataset.train_labels.numpy()
labels = np.array(dataset.train_labels)
# sort labels
idxs_labels = np.vstack((idxs, labels))
idxs_labels = idxs_labels[:, idxs_labels[1, :].argsort()]
idxs = idxs_labels[0, :]
# divide and assign
for i in range(num_users):
rand_set = set(np.random.choice(idx_shard, 2, replace=False))
idx_shard = list(set(idx_shard) - rand_set)
for rand in rand_set:
dict_users[i] = np.concatenate(
(dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0)
return dict_users
-
变量初始化:
num_shards, num_imgs = 200, 250
:将数据集分为200个碎片(shards),每个碎片包含250个图像。idx_shard
:包含200个碎片索引的列表。dict_users
:一个字典,其中的键是用户索引,值是一个空numpy数组,用于存储每个用户分配的图像索引。idxs
:包含数据集中所有图像索引的numpy数组。
-
获取标签并排序:
- 从数据集中获取标签,并将它们与对应的图像索引一起存储在
idxs_labels
中。 - 使用
idxs_labels[1, :].argsort()
按标签对图像索引进行排序,确保具有相同标签的图像彼此相邻。
- 从数据集中获取标签,并将它们与对应的图像索引一起存储在
-
分配非I.I.D.数据:
- 对于每个用户,从
idx_shard
中随机选择2个碎片(无放回),然后从索引列表中删除这些碎片。 - 对于每个选择的碎片,使用
np.concatenate
将碎片的图像索引添加到该用户的索引数组中。
- 对于每个用户,从
总结
这个函数实现了一种特定的非I.I.D.数据分配策略。通过将数据集分成碎片,并确保每个用户都从不同的碎片中获得数据,它可以模拟现实世界中的场景,例如不同的用户可能只访问到数据集的一部分。这与I.I.D.采样相反,在I.I.D.采样中,每个数据点独立地具有相同的选择概率,无论用户之间如何分配。在许多分布式机器学习和联邦学习场景中,非I.I.D.数据分配是一种常见的挑战。
如果是初学者,尤其是没有系统学过Python的,有些代码还真的得一句句理解,顺便学习一下Python的语法。比如
np.concatenate((dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0)
-
dict_users[i]
: 这是一个NumPy数组,包含已分配给第i
个用户的图像索引。 -
idxs[rand*num_imgs:(rand+1)*num_imgs]
: 这部分代码选择了一个特定碎片中的图像索引。这里rand
是随机选择的碎片索引,num_imgs
是每个碎片中的图像数量。通过乘法和切片,这个表达式选择了所需碎片中的所有图像索引。 -
np.concatenate((..., ...), axis=0)
: 这个函数接受一个包含两个数组的元组,并在指定的轴(在这种情况下是轴0)上将它们拼接在一起。因为dict_users[i]
和idxs[...]
都是一维数组,所以axis=0
表示它们将端对端地连接在一起。
这个拼接操作的结果是一个新的NumPy数组,包含了第i
个用户已分配的图像索引以及新选择的碎片中的图像索引。然后将这个新数组分配给dict_users[i]
,更新该用户的图像索引集合。