Implementing FedAvg (Federated Averaging)
Now that we understand how the dataset is structured, we arrive at the most important part: the implementation of the FedAvg algorithm.
Anyone who works on federated learning knows that the essence of FedAvg is to collect the weights each client trains locally and average them, with each client's contribution weighted by the number of samples it holds.
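Written out, this is the update rule from the original FedAvg paper (McMahan et al., 2017): the new global model is the sample-weighted average of the local models returned by the sampled clients:

w_{t+1} = \sum_{k \in S_t} \frac{n_k}{n} \, w_{t+1}^{k}, \qquad n = \sum_{k \in S_t} n_k

where S_t is the set of clients sampled in round t, n_k is the number of samples held by client k, and w_{t+1}^{k} is the model client k obtains by training locally starting from the current global model.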
So our main task is to find where FedML collects the weights and aggregates them; any follow-up work (for example, clustering clients before running federated averaging) comes down to modifying exactly this part of the source code.
First the data is loaded, which the previous post covered in detail. Next, the training model is created; the framework factors this step out to reduce coupling, making it easy to extend and call. Finally, a FedAvgTrainer is initialized and train() is executed.
# load data
dataset = load_data(args, args.dataset)  # args.dataset = 'mnist'

# create model
model = create_model(args, model_name=args.model, output_dim=dataset[7])  # dataset[7] = class_num

def create_model(args, model_name, output_dim):
    model = None
    if model_name == "lr" and args.dataset == "mnist":
        logging.info("LogisticRegression + MNIST")
        model = LogisticRegression(28 * 28, output_dim)
    return model

trainer = FedAvgTrainer(dataset, model, device, args)
trainer.train()
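For reference, the LogisticRegression model used here is just a single linear layer over the flattened 28x28 image. A minimal sketch (the exact definition in the FedML source may differ slightly):

import torch

class LogisticRegression(torch.nn.Module):
    def __init__(self, input_dim, output_dim):
        super(LogisticRegression, self).__init__()
        # one linear layer; the softmax is folded into the cross-entropy loss
        self.linear = torch.nn.Linear(input_dim, output_dim)

    def forward(self, x):
        return self.linear(x)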
What FedAvgTrainer initializes (excerpt):
self.model = model
self.model.train()  # training mode: enables BatchNorm statistics updates and Dropout
self.client_list = []
self.setup_clients(train_data_local_num_dict, train_data_local_dict, test_data_local_dict)

def setup_clients(self, train_data_local_num_dict, train_data_local_dict, test_data_local_dict):
    logging.info("############setup_clients (START)#############")
    # only client_num_per_round Client objects are created; they are reused
    # across rounds by swapping in the sampled clients' local datasets
    for client_idx in range(self.args.client_num_per_round):
        c = Client(client_idx, train_data_local_dict[client_idx], test_data_local_dict[client_idx],
                   train_data_local_num_dict[client_idx], self.args, self.device, self.model)
        self.client_list.append(c)
    logging.info("############setup_clients (END)#############")
# sampling: each round, select a subset of the clients to take part in federated averaging
def client_sampling(self, round_idx, client_num_in_total, client_num_per_round):
    if client_num_in_total == client_num_per_round:
        client_indexes = [client_index for client_index in range(client_num_in_total)]
    else:
        num_clients = min(client_num_per_round, client_num_in_total)
        np.random.seed(round_idx)  # make sure for each comparison, we are selecting the same clients each round
        client_indexes = np.random.choice(range(client_num_in_total), num_clients, replace=False)
    logging.info("client_indexes = %s" % str(client_indexes))
    return client_indexes
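Because the seed is the round index, the selection is reproducible across runs. A quick standalone check with made-up numbers (1000 clients in total, 10 sampled per round):

import numpy as np

np.random.seed(0)  # round_idx = 0
print(np.random.choice(range(1000), 10, replace=False))
# prints the same 10 indices on every run, so repeated experiments
# with the same settings compare against the same client subsets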
Next comes the most important piece of the core code, the train function of FedAvgTrainer:
def train(self):
    w_global = self.model.state_dict()
    for round_idx in range(self.args.comm_round):  # loop over communication rounds
        logging.info("################Communication round : {}".format(round_idx))
        w_locals, loss_locals = [], []
        """
        for scalability: following the original FedAvg algorithm, we uniformly sample a fraction of clients in each round.
        Instead of changing the 'Client' instances, our implementation keeps the 'Client' instances and then updates their local dataset
        """
        client_indexes = self.client_sampling(round_idx, self.args.client_num_in_total,
                                              self.args.client_num_per_round)
        logging.info("client_indexes = " + str(client_indexes))

        for idx, client in enumerate(self.client_list):
            # update dataset: point the pre-built Client at the sampled client's data
            client_idx = client_indexes[idx]
            client.update_local_dataset(client_idx, self.train_data_local_dict[client_idx],
                                        self.test_data_local_dict[client_idx],
                                        self.train_data_local_num_dict[client_idx])
            # train on new dataset
            w, loss = client.train(w_global)
            # self.logger.info("local weights = " + str(w))
            # this is where the local weights are collected, one client at a time
            w_locals.append((client.get_sample_number(), copy.deepcopy(w)))
            loss_locals.append(copy.deepcopy(loss))
            logging.info('Client {:3d}, loss {:.3f}'.format(client_idx, loss))

        # update global weights: this is where aggregation happens
        w_global = self.aggregate(w_locals)
        # logging.info("global weights = " + str(w_global))

        # print loss
        loss_avg = sum(loss_locals) / len(loss_locals)
        logging.info('Round {:3d}, Average loss {:.3f}'.format(round_idx, loss_avg))

        if round_idx % self.args.frequency_of_the_test == 0 or round_idx == self.args.comm_round - 1:
            self.model.load_state_dict(w_global)
            self.local_test_on_all_clients(self.model, round_idx)
The aggregation step. This is the part you can rewrite to suit your own ideas, for example to change the federated averaging strategy. w_locals is a list of tuples (client sample count, weights):
def aggregate(self, w_locals):
    # total number of training samples across the sampled clients
    training_num = 0
    for idx in range(len(w_locals)):
        (sample_num, averaged_params) = w_locals[idx]
        training_num += sample_num

    # reuse the first client's state_dict as the accumulator
    (sample_num, averaged_params) = w_locals[0]
    for k in averaged_params.keys():
        for i in range(0, len(w_locals)):
            local_sample_number, local_model_params = w_locals[i]
            w = local_sample_number / training_num  # weight proportional to sample count
            if i == 0:
                averaged_params[k] = local_model_params[k] * w
            else:
                averaged_params[k] += local_model_params[k] * w
    return averaged_params
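To make the weighting concrete, here is a toy check with made-up numbers: two clients holding 100 and 300 samples, each contributing a one-parameter "model":

import torch

# hypothetical (sample_count, state_dict) pairs, as stored in w_locals
w_locals = [
    (100, {'linear.weight': torch.tensor([2.0])}),
    (300, {'linear.weight': torch.tensor([6.0])}),
]
training_num = sum(n for n, _ in w_locals)  # 400
# FedAvg: (100/400) * 2.0 + (300/400) * 6.0 = 5.0
avg = sum((n / training_num) * w['linear.weight'] for n, w in w_locals)
print(avg)  # tensor([5.])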
The train() procedure in Client.py:
class Client:
    def train(self, w_global):
        self.model.load_state_dict(w_global)
        self.model.to(self.device)

        # train and update
        if self.args.client_optimizer == "sgd":
            optimizer = torch.optim.SGD(self.model.parameters(), lr=self.args.lr)
        else:
            optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self.model.parameters()),
                                         lr=self.args.lr, weight_decay=self.args.wd, amsgrad=True)

        epoch_loss = []
        for epoch in range(self.args.epochs):
            batch_loss = []
            for batch_idx, (x, labels) in enumerate(self.local_training_data):
                x, labels = x.to(self.device), labels.to(self.device)
                self.model.zero_grad()
                log_probs = self.model(x)
                loss = self.criterion(log_probs, labels)
                loss.backward()
                optimizer.step()
                batch_loss.append(loss.item())
            epoch_loss.append(sum(batch_loss) / len(batch_loss))
        return self.model.cpu().state_dict(), sum(epoch_loss) / len(epoch_loss)
    def local_test(self, model_global, b_use_test_dataset=False):
        model_global.eval()
        model_global.to(self.device)
        metrics = {
            'test_correct': 0,
            'test_loss': 0,
            'test_precision': 0,
            'test_recall': 0,
            'test_total': 0
        }
        if b_use_test_dataset:
            test_data = self.local_test_data
        else:
            test_data = self.local_training_data
        with torch.no_grad():
            for batch_idx, (x, target) in enumerate(test_data):
                x = x.to(self.device)
                target = target.to(self.device)
                pred = model_global(x)
                loss = self.criterion(pred, target)
                _, predicted = torch.max(pred, -1)
                correct = predicted.eq(target).sum()
                metrics['test_correct'] += correct.item()
                metrics['test_loss'] += loss.item() * target.size(0)
                metrics['test_total'] += target.size(0)
        return metrics
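The returned metrics dict can be reduced to accuracy and average loss like this (a sketch of the kind of reduction local_test_on_all_clients performs; the client and model names are illustrative):

metrics = client.local_test(model, b_use_test_dataset=True)
accuracy = metrics['test_correct'] / metrics['test_total']
avg_loss = metrics['test_loss'] / metrics['test_total']  # loss was accumulated weighted by batch size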
To sum up: FedML's standalone (pseudo-distributed) mode has no notion of synchronous or asynchronous processes; it simply processes clients one by one and collects the results. So to really understand computation in a federated setting, the distributed environment and FedML's own Mobile environment are what you will need to set up and study next.
But if the goal is simply to improve the algorithm and raise the final accuracy, the pseudo-distributed environment is perfectly adequate for experiments.