I have been updating the continual learning series lately. Today I came across a paper on fusing heterogeneous models in federated learning and would like to share it here; the paper is available here.
1. Introduction
Federated learning (FL) has emerged as a new machine learning paradigm in which distributed clients participate in the collaborative training of a centralized model. FL moves model training to the edge: devices such as mobile phones and IoT devices extract knowledge from private, sensitive training data and then upload the learned models to the cloud for aggregation. FL keeps user data stored locally and restricts direct access from the cloud server; this paradigm therefore not only strengthens privacy protection but also brings inherent benefits in model accuracy, cost efficiency, and diversity. Given the enormous data demands of today's machine learning models and the societal concerns around AI (safety and privacy), federated learning has great potential for balancing this trade-off.
Conventional federated learning takes a weighted average over all model weights. However, because data on edge devices is heterogeneous, this can yield an unfair and ineffective global model that cannot be deployed efficiently. To overcome these limitations of previous FL methods, the paper proposes resource-aware federated learning with knowledge extraction and multi-model fusion (FedKEMF); the paper illustrates the overall workflow in a figure (not reproduced here).
2. Method
2.1 Knowledge Extraction via Deep Mutual Learning
The key idea of deep mutual learning is to train multiple neural networks simultaneously while minimizing the Kullback-Leibler (KL) divergence between their outputs. In other words, the KL divergence measures how similar two distributions are; by minimizing the KL divergence between the networks, they can learn knowledge from each other. Therefore, to extract knowledge from the local model, a knowledge network (a network smaller than the local model) is introduced on each client, and deep mutual learning is used to optimize the knowledge network and the local model simultaneously.
To explain the knowledge extraction process intuitively, consider an image classification task. Formally, an edge client holds a local model $\theta$ and a knowledge network (a tiny network) $\theta_g$. First, the local model $\theta$ is updated: for any mini-batch of input data $x$, the cross-entropy loss between the predictions and the ground-truth labels is computed:
$$L_c = -\sum_{i=1}^{N} y^T \log\left(\sigma\left(\theta\left(x_i\right)\right)\right)$$
where $N$ is the mini-batch size and $\sigma$ is the softmax function.
Then the KL divergence between $\theta(x)$ and $\theta_g(x)$ is computed:
$$D_{KL}\left(\theta_g \| \theta\right) = \sum_{i=1}^{N} \sigma\left(\theta_g(x)\right)^T \log\left(\frac{\sigma\left(\theta_g(x)\right)}{\sigma\left(\theta(x)\right)}\right)$$
The overall loss for updating $\theta$ is then:
$$L_\theta = L_c + D_{KL}\left(\theta_g \| \theta\right)$$
The paper presents the local-update procedure as an algorithm listing (not reproduced here).
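Since the listing isn't shown, here is a minimal PyTorch sketch of one local-update pass implementing the two losses above. It is a sketch under stated assumptions, not the paper's code: `local_model`, `knowledge_net`, and `train_loader` are hypothetical placeholders.

import torch
import torch.nn.functional as F

def local_mutual_step(local_model, knowledge_net, train_loader, lr=0.01, device="cpu"):
    # Hypothetical sketch: one pass of deep mutual learning between the
    # local model (theta) and the knowledge network (theta_g).
    opt_theta = torch.optim.SGD(local_model.parameters(), lr=lr)
    opt_g = torch.optim.SGD(knowledge_net.parameters(), lr=lr)
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        # Update theta with L_theta = L_c + KL(theta_g || theta).
        logits = local_model(x)
        with torch.no_grad():
            target_probs = F.softmax(knowledge_net(x), dim=1)
        loss_theta = F.cross_entropy(logits, y) + F.kl_div(
            F.log_softmax(logits, dim=1), target_probs, reduction="batchmean")
        opt_theta.zero_grad()
        loss_theta.backward()
        opt_theta.step()
        # Update theta_g symmetrically with L_c + KL(theta || theta_g).
        logits_g = knowledge_net(x)
        with torch.no_grad():
            target_probs = F.softmax(local_model(x), dim=1)
        loss_g = F.cross_entropy(logits_g, y) + F.kl_div(
            F.log_softmax(logits_g, dim=1), target_probs, reduction="batchmean")
        opt_g.zero_grad()
        loss_g.backward()
        opt_g.step()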
2.2 Multi-Model Knowledge Fusion
FedKEMF provides two model-fusion methods for fusing the edge knowledge at the server. The first is similar to conventional FL in that the weights are aggregated. The second, inspired by (Lin et al. 2020), ensembles all received client models and distills the ensemble's knowledge into a global knowledge network. This section and the experiments focus on ensembling the clients' knowledge, but FedKEMF can also use the conventional fusion method.
Define the ensemble model as $\Theta = \{\theta_g^k\}_{k \in S}$, where $\theta_g^k$ is the knowledge network of the $k$-th client and $S$ is the set of clients communicating with the server in the current round. Then, using unlabeled data, generated data, or public data on the server, the knowledge of the ensemble $\Theta$ is distilled into a global knowledge network $\theta_g$ as follows:
$$L_d = D_{KL}\left(\Theta \| \theta_g\right)$$
The paper describes the full server-side procedure as an algorithm listing (not reproduced here).
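As a stand-in, below is a minimal sketch of the server-side distillation loop in PyTorch. The names `ensemble` (any module returning class probabilities, e.g. built from the client knowledge networks), `global_knowledge_net`, and the unlabeled `server_loader` are assumptions for illustration, not the paper's implementation.

import torch
import torch.nn.functional as F

def cloud_distill(ensemble, global_knowledge_net, server_loader, epochs=1, lr=0.01, device="cpu"):
    # Hypothetical sketch: distill the ensemble Theta into the global
    # knowledge network theta_g using unlabeled server-side data.
    optimizer = torch.optim.SGD(global_knowledge_net.parameters(), lr=lr)
    for _ in range(epochs):
        for x, _ in server_loader:
            x = x.to(device)
            with torch.no_grad():
                teacher_probs = ensemble(x)  # soft targets from the ensemble
            student_log_probs = F.log_softmax(global_knowledge_net(x), dim=1)
            # L_d = KL(Theta || theta_g)
            loss = F.kl_div(student_log_probs, teacher_probs, reduction="batchmean")
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()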
2.3 Knowledge Ensembling
FedKEMF studies three ensembling strategies: max logits, average logits, and majority voting; in practice, max logits works best. For a given input instance $x$, the ensemble model is computed as:
$$\Theta(x) = \text{Ensemble}_{\text{Max}}\left(\left\{\theta_g^k(x)\right\}_{k \in S}\right)$$
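For illustration, the max-logits ensemble can be realized by taking the element-wise maximum over the clients' logits. The sketch below is hypothetical; as the code in the next section shows, the repository actually averages softmax outputs via AvgEnsemble.

import torch

def ensemble_max_logits(logits_list):
    # logits_list: one (batch_size, n_classes) logits tensor per client in S.
    # Take the element-wise maximum over clients for every class.
    return torch.stack(logits_list, dim=0).max(dim=0).values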
3. Code Analysis
The code is available here.
In the code, the author randomly assigns one of three models (resnet20, resnet32, resnet44) to each client:
def init_fl(n_parties, model_name, args):
    if args.env == 'multi-model':
        nets = {net_i: None for net_i in range(n_parties)}
        net_name = ['resnet20', 'resnet32', 'resnet44']
        import random
        for net_i in range(n_parties):
            # randomly assign one of the three ResNet variants to this client
            net_type = net_name[random.randint(0, 2)]
            net = resnet.__dict__[net_type]()
            if args.ckpt_path is not None:
                # path = os.path.join(data_root, "pretrained_models", 'resnet56-4bfd9763.th')
                checkpoint = torch.load(args.ckpt_path, map_location=args.device)
                sd = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint
                net.load_state_dict(sd)
            net = torch.nn.DataParallel(net)
            nets[net_i] = net
        return nets, None, None
    else:
        return init_nets(n_parties, model_name, args)
Every client's knowledge network uses the same architecture, so it is not covered again here.
During training, each client first runs its local update:
def local_update(nets: dict, g_k: nn.Module, selected, args, net_dataidx_map, test_dl_global, logger, lr=0.01, device="cpu"):
    avg_acc = 0.0
    avg_kacc = 0.0
    k_nets = []
    for net_id, net in nets.items():
        if net_id not in selected:
            continue
        dataidxs = net_dataidx_map[net_id]
        noise_level = args.noise
        if net_id == args.n_parties - 1:
            noise_level = 0
        if args.noise_type == 'space':
            train_dl_local, test_dl_local, _, _ = get_dataloader(args.dataset, args.datadir, args.batch_size, 32, dataidxs, noise_level, net_id, args.n_parties - 1)
        else:
            noise_level = args.noise / (args.n_parties - 1) * net_id
            train_dl_local, test_dl_local, _, _ = get_dataloader(args.dataset, args.datadir, args.batch_size, 32, dataidxs, noise_level)
        n_epoch = args.epochs
        logger.info("Training network %s. n_training: %d" % (str(net_id), len(dataidxs)))
        # move the model to cuda device:
        net.to(device)
        kd_agent = Distiller(lr=lr, epochs=n_epoch, device=device)
        # put the local model and a local copy of the global knowledge
        # network into one cohort for deep mutual learning
        nets_cohort = []
        optimizers = []
        nets_cohort.append(net)
        net_optimizer = optim.SGD(net.parameters(), lr)
        optimizers.append(net_optimizer)
        l_g_k = copy.deepcopy(g_k)
        k_nets.append(l_g_k)
        nets_cohort.append(l_g_k)
        gk_optimizer = optim.SGD(l_g_k.parameters(), lr)
        optimizers.append(gk_optimizer)
        #### if use the local test_dl (also need modify get_dataloader function):
        '''
        nets_cohort, lr_ = kd_agent.mutual_kd(nets_cohort, train_dl_local, test_dl_local, optimizers, s_save_path=None)
        acc = compute_accuracy(nets_cohort[0], test_dl_local, device=device)
        print("net %d final test acc %f" % (net_id, acc))
        logger.info("net %d final test acc %f" % (net_id, acc))
        acc_k = compute_accuracy(nets_cohort[1], test_dl_local, device=device)
        print("net %d knowledge network test acc %f" % (net_id, acc_k))
        logger.info("net %d knowledge network test acc %f" % (net_id, acc_k))
        '''
        #### if use the global test_dl:
        nets_cohort, lr_ = kd_agent.mutual_kd(nets_cohort, train_dl_local, test_dl_global, optimizers, s_save_path=None)
        acc = compute_accuracy(nets_cohort[0], test_dl_global, device=device)
        print("net %d final test acc %f" % (net_id, acc))
        logger.info("net %d final test acc %f" % (net_id, acc))
        acc_k = compute_accuracy(nets_cohort[1], test_dl_global, device=device)
        print("net %d knowledge network test acc %f" % (net_id, acc_k))
        logger.info("net %d knowledge network test acc %f" % (net_id, acc_k))
        avg_acc += acc
        avg_kacc += acc_k
    avg_acc /= len(selected)
    avg_kacc /= len(selected)
    logger.info("avg test acc after local update %f" % avg_acc)
    logger.info("avg knowledge test acc after local update %f" % avg_kacc)
    nets_list = list(nets.values())
    return nets_list, k_nets, lr_
Notice that the author puts the local model and the knowledge network together into the nets_cohort list and passes it to the deep mutual learning module, shown below:
def train_students(self, epochs=20, plot_losses=True, save_model=True, save_model_path="./models/student.pth"):
    for student in self.student_cohort:
        student.train()
    loss_arr = []
    length_of_dataset = len(self.train_loader.dataset)
    num_students = len(self.student_cohort)
    print("\nTraining students...")
    for ep in range(epochs):
        epoch_loss = 0.0
        correct = 0
        for optimizer in self.student_optimizers:
            lr = adjust_learning_rate(optimizer, ep, self.lr, epochs, lr_type='cos')
        for (data, label) in self.train_loader:
            data = data.to(self.device)
            label = label.to(self.device)
            for optim in self.student_optimizers:
                optim.zero_grad()
            avg_student_loss = 0
            for i in range(num_students):
                # mutual-learning term: mimic every other student's output
                student_loss = 0
                for j in range(num_students):
                    if i == j:
                        continue
                    student_loss += self.loss_fn(
                        self.student_cohort[i](data), self.student_cohort[j](data)
                    )
                student_loss /= num_students - 1
                # supervised term: cross-entropy against the labels
                student_loss += F.cross_entropy(self.student_cohort[i](data), label)
                student_loss.backward()
                self.student_optimizers[i].step()
                # accumulate a float rather than the autograd graph
                avg_student_loss += student_loss.item()
            avg_student_loss /= num_students
            predictions = []
            correct_preds = []
            for i, student in enumerate(self.student_cohort):
                predictions.append(student(data).argmax(dim=1, keepdim=True))
                correct_preds.append(
                    predictions[i].eq(label.view_as(predictions[i])).sum().item()
                )
            correct += max(correct_preds)
            epoch_loss += avg_student_loss
        epoch_acc = correct / length_of_dataset
        if self.log:
            # log per epoch (the original passed the constant `epochs` as the step)
            self.writer.add_scalar("Training loss/Student", epoch_loss, ep)
            self.writer.add_scalar("Training accuracy/Student", epoch_acc, ep)
        loss_arr.append(epoch_loss)
        print(f"Epoch: {ep + 1}, Loss: {epoch_loss}, Training Accuracy: {epoch_acc}")
    return lr
Here student_cohort is exactly the nets_cohort passed in above, so the number of students is 2. Note that although the paper describes a KL divergence, the mutual-learning loss between the two networks is actually nn.MSELoss(), stored in self.loss_fn, as the code above shows.
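To make that difference concrete, here is a small sketch contrasting the loss the code uses with the KL formulation in the paper; `student_logits` and `teacher_logits` are hypothetical tensors of shape (batch_size, n_classes):

import torch.nn as nn
import torch.nn.functional as F

mse_loss = nn.MSELoss()  # what self.loss_fn holds: MSE between raw logits

def kl_mutual_loss(student_logits, teacher_logits):
    # What the paper's equations describe: KL between softmax distributions.
    return F.kl_div(
        F.log_softmax(student_logits, dim=1),
        F.softmax(teacher_logits, dim=1),
        reduction="batchmean",
    )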
The server-side code is as follows:
def cloud_update(l_ks, g_k, train_loader, test_loader, lr, n_epoch, device):
    # build the teacher as an ensemble of the received client knowledge networks
    ensemble = AvgEnsemble(l_ks)
    # teh_optimizer = optim.SGD(ensemble.parameters(), lr)
    stu_optimizer = optim.SGD(g_k.parameters(), lr)
    kd_agent = Distiller(lr=lr, epochs=n_epoch, device=device)
    # distill the ensemble's knowledge into the global knowledge network g_k
    _, _, _, lr = kd_agent.pure_kd(ensemble, g_k, None, stu_optimizer, train_loader, test_loader, device='cuda')
    return lr
First, the construction of the ensemble:
class AvgEnsemble(nn.Module):
    def __init__(self, net_list):
        super(AvgEnsemble, self).__init__()
        self.estimators = nn.ModuleList(net_list)

    def forward(self, x):
        # softmax each member's logits, then average the probabilities
        outputs = [
            F.softmax(estimator(x), dim=1) for estimator in self.estimators
        ]
        proba = average(outputs)
        return proba
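The `average` helper is not shown in the snippet; a definition consistent with how it is used here (hypothetical, since the repository's own utility isn't reproduced) would be:

def average(outputs):
    # element-wise mean of a list of equally-shaped probability tensors
    return sum(outputs) / len(outputs)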
As you can see, it simply feeds the server-side samples through each client model in turn and averages the outputs. The final step is similar to the client update:
def pure_kd(self, teacher, student, teh_optimizer, stu_optimizer, train_loader, test_loader,
            device='cuda', s_save_path=None):
    '''
    Pure distillation. Distill the teacher's knowledge into the student.
    Only the student is updated.
    :param teacher (torch.nn.Module): Teacher model
    :param student (torch.nn.Module): Student model
    :param train_loader (torch.utils.data.DataLoader): Dataloader for training
    :param test_loader (torch.utils.data.DataLoader): Dataloader for validation/testing
    :param teh_optimizer (torch.optim.*): Optimizer used for training the teacher
    :param stu_optimizer (torch.optim.*): Optimizer used for training the student
    '''
    if self.kd_type == 'VanillaKD':
        kd = VanillaKD(teacher, student, train_loader, test_loader,
                       teh_optimizer, stu_optimizer, lr=self.lr, device=device)
        lr = kd.train_student(epochs=self.epochs, plot_losses=False, save_model=False)
        stu_acc = kd.evaluate(teacher=False)  # evaluate the student
        return kd.teacher_model, kd.student_model, stu_acc, lr
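Putting the pieces together, one FedKEMF communication round roughly chains the two functions above. The following glue code is a hypothetical sketch, not part of the repository:

def fedkemf_round(nets, g_k, selected, args, net_dataidx_map,
                  server_loader, test_dl_global, logger, lr, device):
    # 1. Local update: deep mutual learning between each selected client's
    #    local model and its copy of the global knowledge network.
    nets_list, k_nets, lr = local_update(nets, g_k, selected, args,
                                         net_dataidx_map, test_dl_global,
                                         logger, lr=lr, device=device)
    # 2. Cloud update: ensemble the returned knowledge networks and distill
    #    them into the global knowledge network g_k on server-side data.
    lr = cloud_update(k_nets, g_k, server_loader, test_dl_global,
                      lr, args.epochs, device)
    return nets_list, g_k, lr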