联邦学习数据集划分Dirichlet划分法及其可视化

文章目录

  • 前言
  • 图片效果:
    • 独立同分布效果
    • 非独立同分布效果
  • 一、参数
    • 输入
    • 输出
  • 二、代码
    • 可视化:
    • 标签划分:
    • 代码调用


前言

用于实现并控制联邦学习客户端之间数据集非独立同分布,并将效果可视化


图片效果:

独立同分布效果

  1. 对不同类别的分配效果可视化:
    在这里插入图片描述
  2. 对不同客户端拥有的数据集的可视化:
    在这里插入图片描述

非独立同分布效果

  1. 对不同类别的分配效果可视化:
    在这里插入图片描述
  2. 对不同客户端拥有的数据集的可视化:
    在这里插入图片描述

一、参数

输入

  • classes:标签名称,列表类型
    -示例:[‘airplane’, ‘automobile’, ‘bird’, ‘cat’, ‘deer’, ‘dog’, ‘frog’, ‘horse’, ‘ship’, ‘truck’]
  • train_labels:数据集标签,列表类型
    -示例:[6 9 9 … 9 1 1]
  • alpha:浓度参数,浮点数数值,(0,+∞)
  • client_number:客户端数量,整型数值

输出

client_idcs:各客户端拥有的数据图片下标,列表类型
示例:

client_idcs=[array([  29,   30,   35, ..., 9676, 9683, 9701]), 
 array([ 9171,  9181,  9193, ..., 20167, 20172, 20176]), 
 array([18920, 18925, 18935, ..., 29604, 29609, 29628]),
 array([28887, 28897, 28912, ..., 38602, 38621, 38644]), 
 array([39601, 39606, 39619, ..., 49963, 49971, 49997])]

二、代码

可视化:

def draw_dataset(classes,labels,client_idcs,num_users):
    #设置图片保存位置
    # 构建save文件夹的路径
    # 获取当前文件的父目录  
    parent_dir = os.path.abspath(os.path.join(os.getcwd(), os.pardir))   
    save_dir = os.path.join(parent_dir, 'save/img') 
    # 如果save文件夹不存在,则创建它  
    if not os.path.exists(save_dir):  
        os.makedirs(save_dir)  
    file_path1 = os.path.join(save_dir, '1.png')
    file_path2 = os.path.join(save_dir, '2.png')
    
    # 展示不同label划分到不同client的情况
    n_classes = 10#cifar10有10个类别
    plt.figure(figsize=(12, 8))
    plt.hist([labels[idc]for idc in client_idcs], stacked=True,
             bins=np.arange(min(labels)-0.5, max(labels) + 1.5, 1),
             label=["Client {}".format(i) for i in range(num_users)],
             rwidth=0.5)
    plt.xticks(np.arange(n_classes), classes)
    plt.xlabel("Label type")
    plt.ylabel("Number of samples")
    plt.legend(loc="upper right")
    plt.title("Display Label Distribution on Different Clients")
    plt.savefig(file_path1)

    # 展示不同client上的label分布
    plt.figure(figsize=(12, 8))
    label_distribution = [[] for _ in range(n_classes)]
    for c_id, idc in enumerate(client_idcs):
        for idx in idc:
            label_distribution[labels[idx]].append(c_id)

    plt.hist(label_distribution, stacked=True,
                bins=np.arange(-0.5, num_users + 1.5, 1),
                label=classes, rwidth=0.5)
    plt.xticks(np.arange(num_users), ["Client %d" %
                                        c_id for c_id in range(num_users)])
    plt.xlabel("Client ID")
    plt.ylabel("Number of samples")
    plt.legend()
    plt.title("Display Label Distribution on Different Clients")
    plt.savefig(file_path2)

标签划分:

def dirichlet_split_noniid(classes,train_labels, alpha=100.0, client_number=5):
    '''
    参数为 alpha 的 Dirichlet 分布将数据索引划分为 n_clients 个子集
    '''
    # 总类别数
    n_classes = train_labels.max()+1#也可以自己手动设置
    label_distribution = np.random.dirichlet([alpha]*n_clients, n_classes)
    # 记录每个类别对应的样本下标
    # 返回二维数组
    class_idcs = [np.argwhere(train_labels==y).flatten()
           for y in range(n_classes)]

    # 定义一个空列表作最后的返回值
    client_idcs = [[] for _ in range(n_clients)]
    # 记录N个client分别对应样本集合的索引
    for c, fracs in zip(class_idcs, label_distribution):
        # np.split按照比例将类别为k的样本划分为了N个子集
        # for i, idcs 为遍历第i个client对应样本集合的索引
        for i, idcs in enumerate(np.split(c, (np.cumsum(fracs)[:-1]*len(c)).astype(int))):
            client_idcs[i] += [idcs]
    client_idcs = [np.concatenate(idcs) for idcs in client_idcs]
    draw_dataset(classes,train_labels,client_idcs, n_clients)
    
    return client_idcs

代码调用

train_dataset = datasets.CIFAR10(data_dir, train=True, download=True, transform=trans_cifar10_train)
train_client_idcs = dirichlet_split_noniid(train_dataset.classes,np.array(train_dataset.targets),alpha=100.0,n_clients=5)

  • 9
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

虫本初阳

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值