参考代码,https://github.com/shaoxiongji/federated-learning
将普通的训练代码,更改为模拟联邦学习的代码,选用最简单的方式,使用for循环模拟客户端,采用Fedavg进行聚合,进而实现联邦学习的模拟。
- 核心代码
#准备训练
global_model.to(args.device)
global_model.train()
# 复制当前全局模型global_model的权重
global_weights = global_model.state_dict()
# training
loss_train = []
#联邦学习算法的主要训练循环
for epoch in range(args.epochs): # epochs全局轮次,全局训练循环
local_weights, local_losses = [], [] #每个客户端的本地权重 本地损失
print(f'\n | Global Training Round : {epoch + 1} |\n')
global_model.train()
# 计算每轮训练将选择的客户端的数量。 num_users全部的客户端,frac选择的比例
m = max(int(args.frac * args.num_users), 1)
#随机选择客户端
idxs_users = np.random.choice(range(args.num_users), m, replace=False)
#利用for循环实现每个客户端的本地训练
#遍历被选中的客户端
for idx in idxs_users:
loader_class = build_dist_loaders if not args.data_path else build_dataloaders
train_dataset, valid_dataset, train_loader, valid_loader, train_sampler, valid_sampler = loader_class(args,
tokenizer,
logger,
idx)
#得到客户端本地训练后的权重和损失
w, loss = train(args=args, model=copy.deepcopy(global_model), idx=idx, train_dataset=train_dataset,
valid_dataset=valid_dataset, tokenizer=tokenizer, logger=logger)
local_weights.append(copy.deepcopy(w))
local_losses.append(copy.deepcopy(loss))
#更新全局权重,使用FedAvg
global_weights = FedAvg(local_weights)
#将更新后的全局权重global_weights 加载到全局模型global_model中,以便在下一轮迭代中使用
global_model.load_state_dict(global_weights)
print("local_losses:", local_losses)
#计算本轮训练的平均损失
loss_avg = sum(local_losses) / len(local_losses)
print('Round {:3d}, Average loss {:.3f}'.format(epoch, loss_avg))
loss_train.append(loss_avg)
- Fedavg代码
def FedAvg(w):
w_avg = copy.deepcopy(w[0])
for k in w_avg.keys():
for i in range(1, len(w)):
w_avg[k] += w[i][k]
w_avg[k] = torch.div(w_avg[k], len(w))
return w_avg
以上就可根据自己的训练代码,进行模拟联邦学习的修改。在for循环中,写入自己的train函数即可。但要对数据加载进行处理,对数据进行分割,将数据分成n份(模拟n个客户端),可以按照索引对数据进行区分,在模拟不同客户端训练时,传入该客户端对应的数据进行训练。