FedProto Code Reproduction

This post does not walk through the code in detail; the underlying ideas are covered in the FedProto paper-reading post on my homepage.

Complete Code

import copy

import torch
import torchvision
import numpy as np
from torch import nn, optim
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
import torch.nn.functional as F


# ----------------------------------Data management----------------------------------


class GetData:
    def __init__(self, isIID, num_clients, dev, beta=0.4, BatchSize=64):
        self.dev = dev
        self.beta = beta
        self.is_iid = isIID
        self.num_of_clients = num_clients  # number of clients
        self.BatchSize = BatchSize

    def load_data(self):
        # Load the MNIST dataset
        train_dataset = torchvision.datasets.MNIST("./dataset", train=True, download=False)
        test_dataset = torchvision.datasets.MNIST("./dataset", train=False, download=False)
        # Separate the data from the labels
        train_data = train_dataset.data.to(torch.float)
        train_labels = np.array(train_dataset.targets)

        test_data = test_dataset.data.to(torch.float)
        test_labels = test_dataset.targets
        # Scale the raw pixel values to [0, 1] first, then standardize so the data
        # has zero mean and unit standard deviation
        train_data = train_data / 255.0
        test_data = test_data / 255.0
        mean = train_data.mean().item()
        std = train_data.std().item()

        transform = transforms.Compose([
            transforms.Normalize((mean,), (std,))
        ])

        train_data = transform(train_data)
        test_data = transform(test_data)

        # train_data.shape[0] <-> len(train_data) is the dataset size
        train_data_size = train_data.shape[0]
        test_data_size = test_data.shape[0]

        test_DataLoader = DataLoader(TensorDataset(test_data, test_labels), batch_size=self.BatchSize)

        # Number of classes in the full dataset (MNIST)
        nclass = np.max(train_labels) + 1

        # Data partitioning
        client_train_data = {}
        client_train_label = {}
        distribution = {}
        if self.is_iid:
            # Fixing the random seed makes np.random.permutation(train_data_size)
            # produce the same sequence on every run, so the experiment is reproducible
            np.random.seed(12)

            # Randomly permute the training-sample indices, giving an index array
            # idxs (this shuffles the dataset)
            idxs = np.random.permutation(train_data_size)
            # Split the index array idxs into as many sub-arrays as there are clients
            batch_idxs = np.array_split(idxs, self.num_of_clients)
            # Iterate over all clients
            for i in range(self.num_of_clients):
                # Assign each client its data and labels according to batch_idxs[i]
                client_train_data[i] = train_data[batch_idxs[i]]
                client_train_label[i] = train_labels[batch_idxs[i]]

            return client_train_data, client_train_label, test_DataLoader
        else:
            n_clients = self.num_of_clients
            train_label = train_labels

            np.random.seed(123)
            # [self.beta] * n_clients creates a list of length n_clients in which
            # every element is self.beta
            # Draw a matrix of shape (nclass, n_clients) recording the fraction of
            # each class assigned to each client
            label_distribution = np.random.dirichlet([self.beta] * n_clients, nclass)

            # For every integer y in [0, nclass-1], find the indices of all elements
            # of train_label equal to y and flatten them into a 1-D array
            # class_idcs is a list whose elements are 1-D arrays of sample indices,
            # one array per class
            class_idcs = [np.argwhere(train_label == y).flatten() for y in range(nclass)]

            # client_idcs holds n_clients empty lists, one index collection per client
            client_idcs = [[] for _ in range(n_clients)]

            # zip pairs class_idcs with label_distribution and we iterate over the pairs
            # Each iteration handles one class: c is the 1-D array of sample indices for
            # that class, and fracs is the 1-D array giving the fraction of that class
            # assigned to each client
            for c, fracs in zip(class_idcs, label_distribution):
                # np.split partitions this class's samples into n_clients subsets
                # according to the fractions in fracs
                # np.cumsum(fracs)[:-1] * len(c): take the cumulative sum of fracs, drop
                # its last element so its length matches the number of split points, then
                # multiply by len(c); casting to int gives the split points for c
                # i, idcs iterates over the index subset belonging to client i
                for i, idcs in enumerate(np.split(c, (np.cumsum(fracs)[:-1] * len(c)).astype(int))):
                    client_idcs[i] += [idcs]

            for i in range(self.num_of_clients):
                idcs = client_idcs[i]
                # Record how many samples of each class this client holds
                distribution[i] = [len(c) for c in idcs]

                client_train_data[i] = train_data[np.concatenate(idcs)]
                client_train_label[i] = train_label[np.concatenate(idcs)]

            return client_train_data, client_train_label, test_DataLoader


# ----------------------------------CNN model definition----------------------------------
class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=5, stride=1, padding=2)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, stride=1, padding=2)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.fc1 = nn.Linear(7 * 7 * 64, 512)
        self.fc2 = nn.Linear(512, 10)

    def forward(self, inputs):
        tensor = inputs.view(-1, 1, 28, 28)
        tensor = F.relu(self.conv1(tensor))
        tensor = self.pool1(tensor)
        tensor = F.relu(self.conv2(tensor))
        tensor = self.pool2(tensor)
        tensor = tensor.view(-1, 7 * 7 * 64)
        tensor = F.relu(self.fc1(tensor))
        # Keep the penultimate-layer activations as the feature representation
        # used to build prototypes
        features = tensor
        tensor = self.fc2(tensor)
        return tensor, features


# ----------------------------------Client----------------------------------
class client:
    def __init__(self, mu, device):
        super(client, self).__init__()
        self.device = device
        self.mu = mu

    def loc_update(self, **kwargs):
        net = kwargs['net']
        trainDataSet = kwargs['trainDataSet']
        global_protos = kwargs['global_protos']
        local_epoch = kwargs['local_epoch']
        localBatchSize = kwargs['localBatchSize']

        train_loader = DataLoader(trainDataSet, batch_size=localBatchSize, shuffle=True)
        # Update the client's local model and return its local prototypes
        local_protos = self.train_net(net, train_loader, global_protos, local_epoch)
        return local_protos

    def train_net(self, net, train_loader, global_protos, local_epoch):
        net = net.to(self.device)
        # Create the SGD optimizer
        # lr: local learning rate
        # momentum: accelerates SGD along consistent gradient directions and damps oscillation
        # weight_decay: L2 regularization, which helps prevent overfitting
        optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5)
        # Loss function
        criterion = nn.CrossEntropyLoss()
        criterion.to(self.device)
        agg_protos_label = {}
        protos_label_num = {}
        # Local training
        for epoch in range(local_epoch):
            # Iterate over train_loader one mini-batch at a time
            for images, labels in train_loader:
                # ---------------------------------Zero the gradients------------------------------
                optimizer.zero_grad()

                images = images.to(self.device)
                labels = labels.to(self.device)
                outputs, f = net(images)
                # ---------------------------------Compute the loss------------------------------
                # Cross-entropy loss: the gap between the predictions and the true labels
                lossCE = criterion(outputs, labels)
                # Prototype loss, computed with the mean-squared-error loss (MSELoss)
                loss_mse = nn.MSELoss()
                # If there are no global prototypes yet, the prototype loss is 0
                if len(global_protos) == 0:
                    lossProto = 0 * lossCE
                else:
                    # Copy the feature representations
                    f_new = copy.deepcopy(f.data)
                    i = 0
                    # Iterate over the labels in this batch
                    for label in labels:
                        # Is this label present in the global-prototype dictionary
                        # (i.e. a class this client holds locally)?
                        if label.item() in global_protos.keys():
                            # Substitute the global prototype for this sample's feature vector
                            f_new[i, :] = global_protos[label.item()][0].data
                        i += 1
                    # MSE between the prototype-substituted features and the original
                    # features; in effect, the distance between the local features and
                    # the global prototypes
                    lossProto = loss_mse(f_new, f)
                # Scale the prototype loss by mu to control its share of the total loss
                lossProto = lossProto * self.mu

                # Total loss
                loss = lossCE + lossProto
                # ---------------------------------Backpropagation------------------------------
                loss.backward()
                # ---------------------------------Optimizer step------------------------------
                optimizer.step()

                # ---------------------------------Collect prototypes------------------------------
                # Only during the final local epoch
                if epoch == local_epoch - 1:
                    # Disable gradient tracking; no backpropagation is needed here
                    with torch.no_grad():
                        for i in range(len(labels)):
                            # Has this label already been seen in agg_protos_label?
                            if labels[i].item() in agg_protos_label:
                                # Accumulate this sample's feature vector onto the label's
                                # running sum and bump the label's sample counter
                                agg_protos_label[labels[i].item()] += copy.deepcopy(f[i, :].detach())
                                protos_label_num[labels[i].item()] += 1
                            else:
                                # Start the label's running sum with this sample's feature
                                # vector and initialize its counter to 1
                                agg_protos_label[labels[i].item()] = copy.deepcopy(f[i, :].detach())
                                protos_label_num[labels[i].item()] = 1

        # ---------------------------------Compute the local prototype of each class------------------------------
        # Each client holds only part of the data, so a class's local prototype is the
        # mean feature vector over that client's samples of the class; taking the mean
        # keeps the prototype representative of the class regardless of how many
        # samples the client happens to hold
        for label in agg_protos_label:
            agg_protos_label[label] = agg_protos_label[label] / protos_label_num[label]

        return agg_protos_label


class FederatedAggregation:
    """
    Federated aggregation
    """

    def weight_calculate(self, **kwargs):
        online_clients = kwargs['online_clients']
        clients_data = kwargs['clients_data']
        clients_label = kwargs['clients_label']

        # Weight each client by its number of training samples
        online_clients_len = [len(clients_data[i]) for i in range(online_clients)]
        online_clients_all = np.sum(online_clients_len)
        freq = np.array(online_clients_len) / online_clients_all
        return freq


# ----------------------------------Server----------------------------------
class server:
    def __init__(self):
        super(server, self).__init__()

    def proto_aggregation(self, freq, online_clients, local_protos_dict):
        # agg_protos_label accumulates the weighted prototype sums;
        # label_weight tracks the total weight of the clients contributing to each label
        agg_protos_label = {}
        label_weight = {}
        # Iterate over the clients
        for idx in range(online_clients):
            # Fetch this client's local prototypes
            local_protos = local_protos_dict[idx]
            # Add the weighted local prototypes into the global accumulators
            for label in local_protos.keys():
                if label in agg_protos_label:
                    agg_protos_label[label] += local_protos[label].data * freq[idx]
                    label_weight[label] += freq[idx]
                else:
                    agg_protos_label[label] = local_protos[label].data * freq[idx]
                    label_weight[label] = freq[idx]

        # Divide each weighted sum by the total weight of the clients that actually hold
        # that label, yielding a proper weighted mean as the final global prototype
        # (wrapped in a list because train_net reads global_protos[label][0])
        for label in agg_protos_label:
            agg_protos_label[label] = [agg_protos_label[label] / label_weight[label]]

        return agg_protos_label

    def server_update(self, **kwargs):
        online_clients = kwargs['online_clients']
        clients_data = kwargs['clients_data']
        clients_label = kwargs['clients_label']
        local_protos = kwargs['local_protos']

        fed_aggregation = FederatedAggregation()
        # Compute each online client's weight; because the data are heterogeneous,
        # aggregation uses a weighted average whose weights reflect each client's data volume
        freq = fed_aggregation.weight_calculate(online_clients=online_clients, clients_data=clients_data,
                                                clients_label=clients_label)

        # Weighted average of the local prototypes
        global_protos = self.proto_aggregation(freq, online_clients, local_protos)
        return global_protos


# ----------------------------------Evaluate the model on the test set: accuracy and average loss----------------------------------
class test_accuracy:

    def test_accuracy(self, net, testDataLoader, dev):
        criterion = nn.CrossEntropyLoss()
        criterion.to(dev)
        # Collect the per-batch test losses
        loss_collector = []
        with torch.no_grad():
            sum_accu = 0
            num = 0
            # Iterate over the test set
            for data, label in testDataLoader:
                data, label = data.to(dev), label.to(dev)
                output, _ = net(data)
                loss = criterion(output, label)
                loss_collector.append(loss.item())
                output = torch.argmax(output, dim=1)
                sum_accu += (output == label).float().mean()
                num += 1

            accuracy = sum_accu / num
            avg_loss = sum(loss_collector) / len(loss_collector)
        return avg_loss, accuracy


if __name__ == "__main__":
    # ----------------------------------Parameter settings----------------------------------
    dev = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # ----------------------------------Create the clients and assign data----------------------------------
    IID = False
    # Assume there are 10 clients
    get_data = GetData(isIID=IID, num_clients=10, dev=dev)
    clients_data, clients_label, testDataLoader = get_data.load_data()

    # ----------------------------------Initialize the models----------------------------------
    # Number of communication rounds
    rounds = 50
    # batch_size
    batch_size = 64
    # Number of clients
    num_in_comm = 10
    # Weight coefficient of the prototype loss
    mu = 2

    # ----------------------------------Training----------------------------------
    writer = SummaryWriter("logs")
    # global_protos starts out as an empty dictionary
    global_protos = {}
    client_model = []
    for i in range(num_in_comm):
        net = SimpleCNN()
        client_model.append(net)
    # Communication between the clients and the server
    for curr_round in range(1, rounds + 1):
        # Reset the accumulators each round so earlier rounds do not leak into this round's averages
        global_loss = 0.0
        global_acc = 0.0
        client_protos = {}
        for k in range(num_in_comm):
            my_client = client(mu, dev)
            train_data = clients_data[k]
            train_label = torch.tensor(clients_label[k])
            # Prototypes obtained from this client's local training
            local_protos = my_client.loc_update(online_clients=num_in_comm, net=client_model[k],
                                                trainDataSet=TensorDataset(train_data, train_label),
                                                local_epoch=2, global_protos=global_protos, localBatchSize=batch_size)
            client_protos[k] = local_protos
            accuracy = test_accuracy()
            local_loss, local_acc = accuracy.test_accuracy(client_model[k], testDataLoader, dev)
            global_loss += local_loss
            global_acc += local_acc
            print(
                '[Round: %d Client: %d] |test : accuracy: %f  loss: %f ' % (curr_round, k, local_acc, local_loss))

        global_loss, global_acc = global_loss / num_in_comm, global_acc / num_in_comm
        writer.add_scalar("global_loss", global_loss, curr_round)
        writer.add_scalar("global_acc", global_acc, curr_round)
        print(
            '----------------------------------[Round: %d] accuracy: %f  loss: %f----------------------------------'
            % (curr_round, global_acc, global_loss))
        # Aggregate the local prototypes into this round's updated global prototypes
        s = server()
        global_protos = s.server_update(online_clients=num_in_comm, clients_data=clients_data,
                                        clients_label=clients_label, local_protos=client_protos)
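
To sanity-check the non-IID split produced by GetData, here is a minimal sketch (assuming the classes above are already defined and the dataset is present in ./dataset) that prints each client's per-class sample counts, making the Dirichlet skew visible:

import numpy as np
import torch

dev = torch.device("cpu")
get_data = GetData(isIID=False, num_clients=10, dev=dev)
clients_data, clients_label, _ = get_data.load_data()
for k in range(10):
    labels = np.asarray(clients_label[k])
    # bincount over the 10 MNIST classes; a smaller beta gives a more skewed split
    print("client %d: n=%d, per-class=%s" % (k, len(labels), np.bincount(labels, minlength=10).tolist()))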

Results

In FedProto, local updates are prototype-based: each client computes local prototypes from its own data and sends them to the server, which aggregates them into global prototypes. Because a global prototype is obtained by aggregating many local prototypes, it does not itself correspond to a concrete global model, so it cannot be used directly to update one.
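
As a toy illustration of the weighted aggregation performed by server.proto_aggregation above, here is a sketch with hypothetical 2-D feature vectors and two clients (the weights would normally come from FederatedAggregation.weight_calculate):

import torch

freq = [0.6, 0.4]  # hypothetical client weights, summing to 1
local_protos_dict = {
    0: {0: torch.tensor([1.0, 0.0])},                               # client 0 holds only label 0
    1: {0: torch.tensor([0.0, 1.0]), 1: torch.tensor([2.0, 2.0])},  # client 1 holds labels 0 and 1
}
global_protos = server().proto_aggregation(freq, 2, local_protos_dict)
print(global_protos[0][0])  # tensor([0.6000, 0.4000]) = (0.6*[1,0] + 0.4*[0,1]) / (0.6+0.4)
print(global_protos[1][0])  # tensor([2., 2.]) = (0.4*[2,2]) / 0.4, since only client 1 holds label 1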

To work around this, FedProto evaluates global performance indirectly: the server averages the test accuracy and loss of all the local models and treats these averages as the global accuracy and global loss. This gives a rough estimate of global performance, but note that it is not the performance of an actual global model.
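
In symbols, with K online clients in round t, the reported "global" metrics are plain means of the per-client test metrics:

\mathrm{acc}_{\mathrm{global}}^{(t)} = \frac{1}{K}\sum_{k=1}^{K}\mathrm{acc}_{k}^{(t)},
\qquad
\mathrm{loss}_{\mathrm{global}}^{(t)} = \frac{1}{K}\sum_{k=1}^{K}\mathrm{loss}_{k}^{(t)}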
