Generating non-IID federated learning data with the Dirichlet distribution

1. Given the training labels of the dataset and the number of clients, generate each client's sample indices according to a Dirichlet distribution.
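import numpy as np

def dirichlet_split_noniid(train_labels, alpha, n_clients):
    '''
    Partition the sample indices into n_clients subsets according to a
    Dirichlet distribution with concentration parameter alpha.
    '''
    # Accept a NumPy array, a list, or a CPU torch tensor (e.g. trainset.targets)
    train_labels = np.asarray(train_labels)
    n_classes = int(train_labels.max()) + 1
    # (K, N) label distribution matrix: row k gives the fraction of class k
    # that is assigned to each of the N clients
    label_distribution = np.random.dirichlet([alpha]*n_clients, n_classes)
    # (K, ...) the sample indices belonging to each of the K classes
    class_idcs = [np.argwhere(train_labels == y).flatten()
                  for y in range(n_classes)]
    # the sample indices assigned to each of the N clients
    client_idcs = [[] for _ in range(n_clients)]
    for k_idcs, fracs in zip(class_idcs, label_distribution):
        # np.split divides the class-k indices k_idcs into N subsets
        # in the proportions given by fracs;
        # i is the client id and idcs are the indices that client receives
        for i, idcs in enumerate(np.split(k_idcs,
                                          (np.cumsum(fracs)[:-1]*len(k_idcs)).
                                          astype(int))):
            client_idcs[i] += [idcs]
    client_idcs = [np.concatenate(idcs) for idcs in client_idcs]
    return client_idcs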
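The splitting rule implemented by `dirichlet_split_noniid` can be written out as follows, with $K$ classes, $N$ clients (numbered $1,\dots,N$ here, whereas the code counts from 0), and $I_k$ the ordered index set of class $k$. The symbols $\mathbf{p}_k$ and $c_{k,i}$ are introduced only for this illustration; the truncation by `astype(int)` corresponds to the floor because all values are non-negative.

```latex
\begin{aligned}
\mathbf{p}_k &= (p_{k,1}, \dots, p_{k,N}) \sim \mathrm{Dir}(\alpha, \dots, \alpha), && k = 0, \dots, K-1 \\
c_{k,i} &= \Big\lfloor |I_k| \sum_{j=1}^{i} p_{k,j} \Big\rfloor, && i = 1, \dots, N-1, \quad c_{k,0} = 0,\; c_{k,N} = |I_k| \\
\text{client } i &\leftarrow I_k\big[c_{k,i-1} : c_{k,i}\big], && i = 1, \dots, N
\end{aligned}
```

A small $\alpha$ (such as the 0.1 used below) concentrates each $\mathbf{p}_k$ on a few clients, so each client tends to hold only a few classes; a large $\alpha$ pushes every $\mathbf{p}_k$ toward the uniform vector $(1/N, \dots, 1/N)$, which approaches an IID split.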
2. Load the training set and call the function to obtain each client's training-sample indices
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets, transforms

def load_dataset_fmnist():
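    # Note: 0.1307/0.3081 are the usual MNIST statistics; the commonly cited
    # FashionMNIST values are about 0.2860/0.3530
    mnist_mean, mnist_std = 0.1307, 0.3081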
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((mnist_mean,), (mnist_std,))])

    # transform = transforms.Compose([transforms.ToTensor()])

    fmnist_train_dataset = datasets.FashionMNIST(root="./data/FMNIST/data", train=True, transform=transform, download=True)
    fmnist_test_dataset = datasets.FashionMNIST(root="./data/FMNIST/data", train=False, transform=transform, download=True)

    return fmnist_train_dataset, fmnist_test_dataset

trainset, testset = load_dataset_fmnist()
labels = np.asarray(trainset.targets)  # labels as a NumPy array, used for plotting
classes = trainset.classes
n_classes = len(classes)
dirichlet_alpha = 0.1
n_clients = 4
client_idcs = dirichlet_split_noniid(trainset.targets, alpha=dirichlet_alpha, n_clients=n_clients)
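The function above draws from NumPy's global random state, so the split changes on every run unless `np.random.seed` is called first. Below is a minimal sketch, assuming you want reproducible splits: `dirichlet_split_noniid_seeded` and its `seed` argument are hypothetical names, the logic mirrors the original but draws from a local `np.random.default_rng` generator, and synthetic labels (10 classes of 600 samples each) stand in for `trainset.targets` so the example runs without downloading FashionMNIST. It also checks that every sample index lands on exactly one client.

```python
import numpy as np


def dirichlet_split_noniid_seeded(train_labels, alpha, n_clients, seed=None):
    """Hypothetical seeded variant of dirichlet_split_noniid.

    Same splitting logic, but uses a local np.random.Generator so the
    partition is reproducible for a given seed.
    """
    rng = np.random.default_rng(seed)
    train_labels = np.asarray(train_labels)
    n_classes = int(train_labels.max()) + 1
    # (K, N): row k holds the fraction of class k assigned to each client
    label_distribution = rng.dirichlet([alpha] * n_clients, n_classes)
    client_idcs = [[] for _ in range(n_clients)]
    for k, fracs in enumerate(label_distribution):
        k_idcs = np.flatnonzero(train_labels == k)
        # Cut points between consecutive clients along this class's indices
        cuts = (np.cumsum(fracs)[:-1] * len(k_idcs)).astype(int)
        for i, idcs in enumerate(np.split(k_idcs, cuts)):
            client_idcs[i].append(idcs)
    return [np.concatenate(idcs) for idcs in client_idcs]


# Synthetic stand-in for trainset.targets: 10 classes x 600 samples each
labels = np.repeat(np.arange(10), 600)
splits = dirichlet_split_noniid_seeded(labels, alpha=0.1, n_clients=4, seed=0)

# Sanity check: every sample index is assigned to exactly one client
all_idcs = np.concatenate(splits)
assert len(all_idcs) == len(labels)
assert len(np.unique(all_idcs)) == len(labels)

# The same seed reproduces the same partition
again = dirichlet_split_noniid_seeded(labels, alpha=0.1, n_clients=4, seed=0)
assert all(np.array_equal(a, b) for a, b in zip(splits, again))
print([len(s) for s in splits])  # samples per client (sizes vary with alpha)
```

Because the splits are disjoint and cover all indices, the client datasets together contain each training sample exactly once; only the class mix per client changes with `alpha`.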
3. Plot the label distribution on each client
plt.figure(figsize=(12, 8))
label_distribution = [[] for _ in range(n_classes)]
for c_id, idc in enumerate(client_idcs):
    for idx in idc:
        label_distribution[labels[idx]].append(c_id)

# one bin per client, centered on its integer id (no trailing empty bin)
plt.hist(label_distribution, stacked=True,
         bins=np.arange(-0.5, n_clients + 0.5, 1),
         label=classes, rwidth=0.5)
plt.xticks(np.arange(n_clients),
           ["Client %d" % c_id for c_id in range(n_clients)])
plt.xlabel("Client ID")
plt.ylabel("Number of samples")
plt.legend()
plt.title("Display Label Distribution on Different Clients")
plt.show()
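Besides the histogram, a per-client class-count table is often handy for logging or for checking the split numerically. A minimal sketch, assuming the labels can be converted with `np.asarray` (a CPU tensor such as `trainset.targets` can): the helper name `client_label_counts` is hypothetical, and it uses `np.bincount` instead of looping over every sample in Python. The toy labels and indices are illustrative only.

```python
import numpy as np


def client_label_counts(labels, client_idcs, n_classes):
    """Hypothetical helper: (n_clients, n_classes) matrix of label counts."""
    labels = np.asarray(labels)
    # bincount with minlength keeps a column for classes a client never sees
    return np.stack([np.bincount(labels[idcs], minlength=n_classes)
                     for idcs in client_idcs])


# Toy example: 3 classes, 2 clients
labels = np.array([0, 0, 1, 1, 1, 2])
client_idcs = [np.array([0, 2, 5]), np.array([1, 3, 4])]
counts = client_label_counts(labels, client_idcs, n_classes=3)
print(counts)
# [[1 1 1]
#  [1 2 0]]
```

With the objects from step 2 the call would be `client_label_counts(labels, client_idcs, n_classes)`, giving one row per client and one column per class; each row sums to that client's number of samples.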

