非独立同分布抽样。以CIFAR数据集为例,Python。

本文介绍了一个将CIFAR10数据集划分为非独立同分布(Non-IID)的函数,通过为每个用户随机分配不同碎片以模拟真实世界数据异质性,不同于独立同分布(I.I.D.)的采样方法,适用于分布式和联邦学习环境。
摘要由CSDN通过智能技术生成

Non-IID(非独立同分布):与IID相反,Non-IID数据可能存在某些依赖关系或可能来自不同的分布。在此代码片段中,可以选择“相等”或“不等”分割,分别对应于所有用户之间平均分配的数据(但不是同分布的),或某些用户可能获得更多或更少数据的情况。

def cifar_noniid(dataset, num_users):
    """
    Sample non-I.I.D client data from CIFAR10 dataset
    :param dataset:
    :param num_users:
    :return:
    """
    num_shards, num_imgs = 200, 250
    idx_shard = [i for i in range(num_shards)]
    dict_users = {i: np.array([]) for i in range(num_users)}
    idxs = np.arange(num_shards*num_imgs)
    # labels = dataset.train_labels.numpy()
    labels = np.array(dataset.train_labels)

    # sort labels
    idxs_labels = np.vstack((idxs, labels))
    idxs_labels = idxs_labels[:, idxs_labels[1, :].argsort()]
    idxs = idxs_labels[0, :]

    # divide and assign
    for i in range(num_users):
        rand_set = set(np.random.choice(idx_shard, 2, replace=False))
        idx_shard = list(set(idx_shard) - rand_set)
        for rand in rand_set:
            dict_users[i] = np.concatenate(
                (dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0)
    return dict_users
  1. 变量初始化

    • num_shards, num_imgs = 200, 250:将数据集分为200个碎片(shards),每个碎片包含250个图像。
    • idx_shard:包含200个碎片索引的列表。
    • dict_users:一个字典,其中的键是用户索引,值是一个空numpy数组,用于存储每个用户分配的图像索引。
    • idxs:包含数据集中所有图像索引的numpy数组。
  2. 获取标签并排序

    • 从数据集中获取标签,并将它们与对应的图像索引一起存储在idxs_labels中。
    • 使用idxs_labels[1, :].argsort()按标签对图像索引进行排序,确保具有相同标签的图像彼此相邻。
  3. 分配非I.I.D.数据

    • 对于每个用户,从idx_shard中随机选择2个碎片(无放回),然后从索引列表中删除这些碎片。
    • 对于每个选择的碎片,使用np.concatenate将碎片的图像索引添加到该用户的索引数组中。

总结

这个函数实现了一种特定的非I.I.D.数据分配策略。通过将数据集分成碎片,并确保每个用户都从不同的碎片中获得数据,它可以模拟现实世界中的场景,例如不同的用户可能只访问到数据集的一部分。这与I.I.D.采样相反,在I.I.D.采样中,每个数据点独立地具有相同的选择概率,无论用户之间如何分配。在许多分布式机器学习和联邦学习场景中,非I.I.D.数据分配是一种常见的挑战。


如果是初学者,尤其是没有系统学过Python的,有些代码还真的得一句句理解,顺便学习一下Python的语法。比如

np.concatenate((dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0)
  • dict_users[i]: 这是一个NumPy数组,包含已分配给第i个用户的图像索引。

  • idxs[rand*num_imgs:(rand+1)*num_imgs]: 这部分代码选择了一个特定碎片中的图像索引。这里rand是随机选择的碎片索引,num_imgs是每个碎片中的图像数量。通过乘法和切片,这个表达式选择了所需碎片中的所有图像索引。

  • np.concatenate((..., ...), axis=0): 这个函数接受一个包含两个数组的元组,并在指定的轴(在这种情况下是轴0)上将它们拼接在一起。因为dict_users[i]idxs[...]都是一维数组,所以axis=0表示它们将端对端地连接在一起。

这个拼接操作的结果是一个新的NumPy数组,包含了第i个用户已分配的图像索引以及新选择的碎片中的图像索引。然后将这个新数组分配给dict_users[i],更新该用户的图像索引集合。

  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值