Implementing FedAvg (Federated Averaging) in FedML's Pseudo-Distributed (Standalone) Mode

Now that we understand how the dataset is put together, we come to the most important part: the implementation of the FedAvg algorithm.

Anyone working on federated learning knows that the essence of FedAvg is to combine the weights trained locally by each client into a weighted average, where a client's contribution is proportional to the number of samples it holds:

w_{t+1} \leftarrow \sum_{k=1}^{K} \frac{n_{k}}{n} w_{t+1}^{k}
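
For example (the sample counts here are made up purely for illustration), if two clients hold n_1 = 600 and n_2 = 400 of the n = 1000 samples, the update is simply

w_{t+1} \leftarrow \frac{600}{1000} w_{t+1}^{1} + \frac{400}{1000} w_{t+1}^{2} = 0.6\,w_{t+1}^{1} + 0.4\,w_{t+1}^{2}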

So our main task is to find where FedML collects the clients' weights and aggregates them; whatever we do next (for example, clustering the clients before running federated averaging) comes down to modifying this part of the source code.


The first step is loading the data, which the previous post covered in detail. Next comes creating the model used for training; the framework factors this part out to reduce coupling, which makes it easy for us to extend and call.

Finally, we initialize a FedAvgTrainer and call train().

# load data
dataset = load_data(args, args.dataset)  # args.dataset = 'mnist'

# create model.
model = create_model(args, model_name=args.model, output_dim=dataset[7])  # dataset[7] = class_num

def create_model(args, model_name, output_dim):
    model = None
    if model_name == "lr" and args.dataset == "mnist":
        logging.info("LogisticRegression + MNIST")
        model = LogisticRegression(28 * 28, output_dim)
    return model

trainer = FedAvgTrainer(dataset, model, device, args)
trainer.train()

Part of what FedAvgTrainer initializes:

self.model = model
self.model.train()  # put the model in training mode (affects BatchNorm and Dropout)

self.client_list = []
self.setup_clients(train_data_local_num_dict, train_data_local_dict, test_data_local_dict)

def setup_clients(self, train_data_local_num_dict, train_data_local_dict, test_data_local_dict):
    logging.info("############setup_clients (START)#############")
    for client_idx in range(self.args.client_num_per_round):  # number of clients participating in each round
        c = Client(client_idx, train_data_local_dict[client_idx], test_data_local_dict[client_idx],
                   train_data_local_num_dict[client_idx], self.args, self.device, self.model)
        self.client_list.append(c)
    logging.info("############setup_clients (END)#############")

# Sampling: in each round, select a subset of clients to take part in federated averaging
def client_sampling(self, round_idx, client_num_in_total, client_num_per_round):
    if client_num_in_total == client_num_per_round:
        client_indexes = [client_index for client_index in range(client_num_in_total)]
    else:
        num_clients = min(client_num_per_round, client_num_in_total)
        np.random.seed(round_idx)  # make sure for each comparison, we are selecting the same clients each round
        client_indexes = np.random.choice(range(client_num_in_total), num_clients, replace=False)
    logging.info("client_indexes = %s" % str(client_indexes))
    return client_indexes
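
A quick standalone sketch of the same sampling logic (not FedML code; the client counts below are made up) shows why seeding with round_idx matters: the same round always draws the same clients, so different runs stay comparable.

import numpy as np

def client_sampling(round_idx, client_num_in_total, client_num_per_round):
    # same logic as above, with the logging stripped out
    if client_num_in_total == client_num_per_round:
        return list(range(client_num_in_total))
    num_clients = min(client_num_per_round, client_num_in_total)
    np.random.seed(round_idx)  # reproducible draw for a given round
    return np.random.choice(range(client_num_in_total), num_clients, replace=False)

print(client_sampling(0, 1000, 10))  # same 10 indexes every time this is run
print(client_sampling(0, 1000, 10))  # identical to the previous line
print(client_sampling(1, 1000, 10))  # round 1 seeds differently, so a different draw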

Then comes the most important piece, the core code: FedAvgTrainer's train function:

def train(self):
    w_global = self.model.state_dict()
    for round_idx in range(self.args.comm_round):  # one iteration per communication round
        logging.info("################Communication round : {}".format(round_idx))

        w_locals, loss_locals = [], []

        """
        for scalability: following the original FedAvg algorithm, we uniformly sample a fraction of clients in each round.
        Instead of changing the 'Client' instances, our implementation keeps the 'Client' instances and then updates their local dataset 
        """
        client_indexes = self.client_sampling(round_idx, self.args.client_num_in_total,
                                              self.args.client_num_per_round)
        logging.info("client_indexes = " + str(client_indexes))

        for idx, client in enumerate(self.client_list):
            # update dataset: reuse the Client instance set up earlier, swapping in the sampled client's local data
            client_idx = client_indexes[idx]
            client.update_local_dataset(client_idx, self.train_data_local_dict[client_idx],
                                        self.test_data_local_dict[client_idx],
                                        self.train_data_local_num_dict[client_idx])

            # train on new dataset
            w, loss = client.train(w_global)
            # self.logger.info("local weights = " + str(w))
            w_locals.append((client.get_sample_number(), copy.deepcopy(w)))
            # the local weights are collected here, one client at a time
            loss_locals.append(copy.deepcopy(loss))
            logging.info('Client {:3d}, loss {:.3f}'.format(client_idx, loss))

        # update global weights: this is where the aggregation happens
        w_global = self.aggregate(w_locals)
        # logging.info("global weights = " + str(w_glob))

        # print loss
        loss_avg = sum(loss_locals) / len(loss_locals)
        logging.info('Round {:3d}, Average loss {:.3f}'.format(round_idx, loss_avg))

        if round_idx % self.args.frequency_of_the_test == 0 or round_idx == self.args.comm_round - 1:
            self.model.load_state_dict(w_global)
            self.local_test_on_all_clients(self.model, round_idx)

The aggregation step. This part can be modified to suit your own ideas, for example to change the federated averaging strategy.

w_locals is a list whose elements are tuples of (client sample count, weights).

def aggregate(self, w_locals):
    training_num = 0
    for idx in range(len(w_locals)):
        (sample_num, averaged_params) = w_locals[idx]
        training_num += sample_num

    (sample_num, averaged_params) = w_locals[0]
    for k in averaged_params.keys():
        for i in range(0, len(w_locals)):
            local_sample_number, local_model_params = w_locals[i]
            w = local_sample_number / training_num
            if i == 0:
                averaged_params[k] = local_model_params[k] * w
            else:
                averaged_params[k] += local_model_params[k] * w
    return averaged_params
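
To check what aggregate computes (and to see where a different strategy would plug in), here is a minimal standalone sketch with two toy clients; the tensors and sample counts are invented for illustration.

import torch

# two toy "clients": (sample_num, state_dict), matching the w_locals layout above
w_locals = [(600, {"weight": torch.ones(2, 2)}),
            (400, {"weight": torch.zeros(2, 2)})]

training_num = sum(n for n, _ in w_locals)
averaged = {k: sum((n / training_num) * w[k] for n, w in w_locals)
            for k in w_locals[0][1].keys()}
print(averaged["weight"])  # every entry is 0.6 = 600/1000 * 1.0 + 400/1000 * 0.0

Replacing the n / training_num weight with 1 / len(w_locals), for instance, would turn this into a plain unweighted average.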

The train() process in Client.py:

class Client:

    def train(self, w_global):
        self.model.load_state_dict(w_global)
        self.model.to(self.device)

        # train and update
        if self.args.client_optimizer == "sgd":
            optimizer = torch.optim.SGD(self.model.parameters(), lr=self.args.lr)
        else:
            optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self.model.parameters()), lr=self.args.lr,
                                              weight_decay=self.args.wd, amsgrad=True)

        epoch_loss = []
        for epoch in range(self.args.epochs):
            batch_loss = []
            for batch_idx, (x, labels) in enumerate(self.local_training_data):
                x, labels = x.to(self.device), labels.to(self.device)
                self.model.zero_grad()
                log_probs = self.model(x)
                loss = self.criterion(log_probs, labels)
                loss.backward()
                optimizer.step()
                batch_loss.append(loss.item())
            epoch_loss.append(sum(batch_loss) / len(batch_loss))
        return self.model.cpu().state_dict(), sum(epoch_loss) / len(epoch_loss)

    def local_test(self, model_global, b_use_test_dataset=False):
        model_global.eval()
        model_global.to(self.device)
        metrics = {
            'test_correct': 0,
            'test_loss': 0,
            'test_precision': 0,
            'test_recall': 0,
            'test_total': 0
        }
        if b_use_test_dataset:
            test_data = self.local_test_data
        else:
            test_data = self.local_training_data
        with torch.no_grad():
            for batch_idx, (x, target) in enumerate(test_data):
                x = x.to(self.device)
                target = target.to(self.device)
                pred = model_global(x)
                loss = self.criterion(pred, target)

                _, predicted = torch.max(pred, -1)
                correct = predicted.eq(target).sum()

                metrics['test_correct'] += correct.item()
                metrics['test_loss'] += loss.item() * target.size(0)
                metrics['test_total'] += target.size(0)

        return metrics
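
local_test_on_all_clients then reduces these per-client counters into round-level metrics. The following sketch is only an assumption about how those counters are meant to be combined (toy numbers, not FedML code):

# toy per-client metrics dicts, shaped like the ones local_test() returns above
metrics_list = [
    {'test_correct': 90, 'test_loss': 30.0, 'test_total': 100},
    {'test_correct': 40, 'test_loss': 25.0, 'test_total': 50},
]

total = sum(m['test_total'] for m in metrics_list)
test_acc = sum(m['test_correct'] for m in metrics_list) / total   # 130 / 150 ≈ 0.867
test_loss = sum(m['test_loss'] for m in metrics_list) / total     # 55 / 150 ≈ 0.367
print(test_acc, test_loss)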

To sum up: FedML's pseudo-distributed (standalone) mode has no notion of synchronous or asynchronous processes; it simply trains one client at a time and then collects and aggregates the results. So to truly understand computation in a real federated learning setting, the distributed environment and FedML's own Mobile environment are things we will need to set up and study later.

But if the goal is just to improve the algorithm and raise the final accuracy, the pseudo-distributed environment is perfectly adequate for experiments.
