论文1 :
Adapter-guided knowledge transfer for heterogeneous federated learning
- Shichong Liu, Haozhe Jin, Zhiwei Tang, Rui Zhai∗, Ke Lu, Junyang Yu, Chenxi Bai
- School of Software, Henan University, Kaifeng, 475000, China
- https://doi.org/10.1016/j.sysarc.2025.103338
- 《Journal of Systems Architecture》(CCF-B期刊)
Abstract
【背景】联邦学习(FL)旨在利用分散的数据协同训练一个全局模型或多个本地模型。
【动机】大多数现有的联邦学习方法侧重于解决客户端之间的统计异质性问题,却常常忽视了模型异质性带来的挑战。
【设计】为了解决统计异质性和模型异质性这两个问题,提出了FedAKT,这是一种新型的模型异质性个性化联邦学习(MHPFL)方法。
- 首先,为了促进跨客户端的知识转移,我们的方法为每个客户端添加了一个小型的同构适配器。
- 其次,我们引入了基于特征的互蒸馏(FMD)机制,该机制能够促进本地模型之间的双向知识交换。
- 第三,我们提出了头部两用(HDU)机制,使每个本地模型的头部能够从不同的视角有效地学习特征信息。
【实验】在CIFAR10、CIFAR - 100和Tiny - ImageNet数据集上进行的大量实验表明,与先进的基线方法相比,FedAKT具有优越性。
Background
- Federated Learning
- Non-IID data(statistical heterogeneity)
- 客户端倾向于设计独特的模型,以避免资源浪费和低性能客户端过载。
Motivations
- 模型异质性和统计异质性问题通常由模型异质个性化联邦学习(MHPFL)来解决。现有的 MHPFL 方法可分为三类:知识蒸馏、模型解耦和类别信息共享。
- 知识蒸馏方法通常依赖公共数据集,获取这些数据集或保证其质量可能具有挑战性。一些方法使用生成模型来生成高质量数据集,但训练生成模型会带来巨大的计算开销。
- 模型解耦方法需要将相同的网络层上传到服务器,这限制了模型的异质性程度。
- 类别信息共享方法,如共享原型或对数几率,可以降低通信成本。然而,它们可能存在泄露敏感数据分布信息的风险,从而限制了其应用。
Challenges
- 在回顾先前的模型异质个性化联邦学习(MHPFL)研究时,有三个核心问题成为了主要的研究重点:
- 确定客户端和服务器之间应该共享哪些知识;
- 开发何种框架以促进异构模型设置下的跨客户端知识转移;
- 实现本地模型之间的有效知识交换。
Overview
我们的FedAKT工作流程如图4所示。为简化图中符号,
F
i
F_{i}
Fi 代表本地异构特征提取器,
H
i
H_{i}
Hi 表示本地异构头部,
i
∈
[
1
,
m
]
i \in[1, m]
i∈[1,m]。此外,
F
i
′
F_{i}'
Fi′ 是本地同构适配器,
F
′
F'
F′ 表示全局同构适配器。
h
i
h_{i}
hi 和
h
i
′
h_{i}'
hi′ 分别表示来自本地异构特征提取器
F
i
F_{i}
Fi 和同构适配器
F
i
′
F_{i}'
Fi′ 的隐藏特征。
F
i
F_{i}
Fi 和
F
i
′
F_{i}'
Fi′ 将输入空间映射到相同的特征空间
R
e
\mathbb{R}^{e}
Re,而
H
i
H_{i}
Hi 将特征空间映射到类别空间
R
c
\mathbb{R}^{c}
Rc。
我们的算法过程分为四个阶段:
- 在初始化阶段,服务器上随机初始化全局同构适配器。此外,每个客户端上随机初始化本地同构适配器、本地异构特征提取器和本地异构头部。
- 在本地训练阶段,本地数据同时输入到同构适配器和异构特征提取器中,以获得相应的隐藏特征。然后,每个客户端执行基于特征的互蒸馏来交换知识。同时,这些隐藏特征被输入到本地异构头部以输出预测结果。最后,计算预测类别与真实标签类别之间的分类损失。
- 在本地更新阶段,每个本地模型(包括异构特征提取器和异构头部)使用梯度下降法进行更新。此外,本地同构适配器也进行更新。
- 在通信阶段,服务器负责聚合本地同构适配器,并广播全局同构适配器。每个客户端接收全局同构适配器,并准备开始新的一轮训练迭代。重复该算法过程,直到客户端模型收敛。
Designs
对于问题1(Q1):受先前研究成果的启发,我们利用这样一个见解,即深度神经网络(DNN)的较低层往往比高层学习到更通用的信息。因此,我们将每个本地模型拆分为一个异构特征提取器(较低层)和一个异构头部(较高层),发现异构特征提取器中的知识更有利于共享。
对于问题2(Q2):由于模型的异构性,模型无法在服务器上直接进行聚合。目前大多数MHPFL方法要么需要复杂的流程,要么知识共享能力有限。因此,我们在每个客户端模型中引入一个小型同构适配器(小型同构特征提取器)。这些同构适配器被上传到服务器,有助于实现跨客户端的有效知识转移。
对于问题3(Q3):受互蒸馏和特征蒸馏的启发,我们提出了一种创新的基于特征的互蒸馏(FMD)机制,支持同构适配器和异构特征提取器之间的双向知识转移。此外,我们引入了头部两用(HDU)机制,使同构适配器和异构特征提取器能够共享客户端独特的本地异构头部。通过HDU,本地异构头部可以从不同角度有效地整合和学习特征信息。
4.2 基于特征的互蒸馏机制 Feature-based mutual distillation mechanism
为了在同构适配器和异构特征提取器之间实现有效的知识交换,我们提出了基于特征的互蒸馏(FMD)机制。在每一轮通信中,客户端
i
i
i 同时训练同构适配器
F
i
′
(
δ
i
)
F_{i}'(\delta_{i})
Fi′(δi)、异构特征提取器
F
i
(
φ
i
)
F_{i}(\varphi_{i})
Fi(φi) 和异构头部
H
i
(
θ
i
)
H_{i}(\theta_{i})
Hi(θi)。具体来说,在公式(8)中,本地数据样本
x
j
∈
B
k
x_{j} \in B_{k}
xj∈Bk 同时输入到
F
i
(
φ
i
)
F_{i}(\varphi_{i})
Fi(φi) 和
F
i
′
(
δ
i
)
F_{i}'(\delta_{i})
Fi′(δi) 中,得到隐藏特征
h
j
∈
R
e
h_{j} \in \mathbb{R}^{e}
hj∈Re 和
h
j
′
∈
R
e
h_{j}' \in \mathbb{R}^{e}
hj′∈Re。这里,
B
k
B_{k}
Bk 代表本地数据
D
i
D_{i}
Di 的一个批次,
j
j
j 表示
B
k
B_{k}
Bk 的第
j
j
j 个样本。由于初始条件不同,
F
i
(
φ
i
)
F_{i}(\varphi_{i})
Fi(φi) 和
F
i
′
(
δ
i
)
F_{i}'(\delta_{i})
Fi′(δi) 可以学习到不同的特征:
h
j
=
F
i
(
φ
i
;
x
j
)
,
h
j
′
=
F
i
′
(
δ
i
;
x
j
)
.
(
8
)
h_{j}=\mathcal{F}_{i}\left(\varphi_{i} ; x_{j}\right), h_{j}'=\mathcal{F}_{i}'\left(\delta_{i} ; x_{j}\right). (8)
hj=Fi(φi;xj),hj′=Fi′(δi;xj).(8)
与传统的使用两个完整模型的软预测进行互蒸馏的方法不同,我们的方法利用包含更多可学习信息的高维隐藏特征。此外,本地维护的
F
i
(
φ
i
)
F_{i}(\varphi_{i})
Fi(φi) 具有更多本地知识,而可共享的
F
i
′
(
δ
i
)
F_{i}'(\delta_{i})
Fi′(δi) 具有更多全局知识。因此,
h
j
h_{j}
hj 包含更多个性化特征信息,
h
j
′
h_{j}'
hj′ 包含更多通用特征信息。我们的方法使用均方误差
L
M
S
E
L_{MSE}
LMSE 来衡量
h
j
h_{j}
hj 和
h
j
′
h_{j}'
hj′ 之间的距离,并将其作为互蒸馏损失。
F
i
(
φ
i
)
F_{i}(\varphi_{i})
Fi(φi) 和
F
i
′
(
δ
i
)
F_{i}'(\delta_{i})
Fi′(δi) 可以通过最小化该损失来学习对方所缺乏的知识。基于特征的互蒸馏损失函数由公式(9)和(10)表示。这里,
L
D
1
L_{D_{1}}
LD1 是异构特征提取器的互蒸馏损失,
L
D
2
L_{D_{2}}
LD2 是同构适配器的互蒸馏损失:
L
D
1
=
∑
j
=
0
∣
B
k
∣
−
1
L
M
S
E
(
h
j
,
h
j
′
)
,
(
9
)
\mathcal{L}_{D_{1}}=\sum_{j=0}^{\left|B_{k}\right|-1} \mathcal{L}_{M S E}\left(h_{j}, h_{j}'\right), (9)
LD1=j=0∑∣Bk∣−1LMSE(hj,hj′),(9)
L
D
2
=
∑
j
=
0
∣
B
k
∣
−
1
L
M
S
E
(
h
j
′
,
h
j
)
.
(
10
)
\mathcal{L}_{D_{2}}=\sum_{j=0}^{\left|B_{k}\right|-1} \mathcal{L}_{M S E}\left(h_{j}', h_{j}\right). (10)
LD2=j=0∑∣Bk∣−1LMSE(hj′,hj).(10)
4.3 头部两用机制 Header dual-use mechanism
由于同构适配器
F
i
′
(
δ
i
)
F_{i}'(\delta_{i})
Fi′(δi) 不是一个完整的网络,无法通过监督学习进行训练。因此,我们提出了头部两用(HDU)机制:同构适配器
F
i
′
(
δ
i
)
F_{i}'(\delta_{i})
Fi′(δi) 和异构特征提取器
F
i
(
φ
i
)
F_{i}(\varphi_{i})
Fi(φi) 共享一个独特的本地异构头部
H
i
(
θ
i
)
H_{i}(\theta_{i})
Hi(θi)。具体来说,在公式(11)中,
h
j
h_{j}
hj 和
h
j
′
h_{j}'
hj′ 都输入到
H
i
(
θ
i
)
H_{i}(\theta_{i})
Hi(θi) 中进行训练,
H
i
(
θ
i
)
H_{i}(\theta_{i})
Hi(θi) 输出相应的软预测
p
j
p_{j}
pj 和
p
j
′
p_{j}'
pj′。HDU机制允许同构适配器与本地异构头部连接,即
F
i
′
(
δ
i
)
∘
H
i
(
θ
i
)
F_{i}'(\delta_{i}) \circ H_{i}(\theta_{i})
Fi′(δi)∘Hi(θi),在每个客户端上形成一个完整的网络。同时,本地异构头部
H
i
(
θ
i
)
H_{i}(\theta_{i})
Hi(θi) 可以有效地吸收来自
F
i
(
φ
i
)
F_{i}(\varphi_{i})
Fi(φi) 和
F
i
′
(
δ
i
)
F_{i}'(\delta_{i})
Fi′(δi) 输出的特征信息:
p
j
=
H
i
(
θ
i
;
h
j
)
,
p
j
′
=
H
i
(
θ
i
;
h
j
′
)
.
(
11
)
p_{j}=\mathcal{H}_{i}\left(\theta_{i} ; h_{j}\right), p_{j}'=\mathcal{H}_{i}\left(\theta_{i} ; h_{j}'\right). (11)
pj=Hi(θi;hj),pj′=Hi(θi;hj′).(11)
在监督学习的指导下,同构适配器和异构特征提取器不会相互误导,而是不断提高性能。随后,在公式(12)和(13)中,输出
p
j
p_{j}
pj 和
p
j
′
p_{j}'
pj′ 分别用于与真实标签
y
y
y 计算监督学习损失。我们使用交叉熵损失
L
C
E
L_{CE}
LCE 作为监督学习损失:
L
P
1
=
∑
j
=
0
∣
B
k
∣
−
1
L
C
E
(
p
j
,
y
j
)
,
(
12
)
\mathcal{L}_{P_{1}}=\sum_{j=0}^{\left|B_{k}\right|-1} \mathcal{L}_{C E}\left(p_{j}, y_{j}\right), (12)
LP1=j=0∑∣Bk∣−1LCE(pj,yj),(12)
L
P
2
=
∑
j
=
0
∣
B
k
∣
−
1
L
C
E
(
p
j
′
,
y
j
)
.
(
13
)
\mathcal{L}_{P_{2}}=\sum_{j=0}^{\left|B_{k}\right|-1} \mathcal{L}_{C E}\left(p_{j}', y_{j}\right). (13)
LP2=j=0∑∣Bk∣−1LCE(pj′,yj).(13)
4.4 模型更新和推理 Model update and inference
在模型更新之前,我们计算总监督学习损失
L
P
L_{P}
LP、异构特征提取器的总训练损失
C
T
1
C_{T_{1}}
CT1 和同构适配器的总训练损失
C
T
2
C_{T_{2}}
CT2。公式如(14)所示,其中
λ
\lambda
λ 是一个超参数,用于控制互蒸馏的强度:
L
P
=
L
P
1
+
L
P
2
,
L
T
1
=
L
P
1
+
λ
⋅
L
D
1
,
L
T
2
=
L
P
2
+
λ
⋅
L
D
2
.
\begin{aligned} &\mathcal{L}_{P}=\mathcal{L}_{P_{1}}+\mathcal{L}_{P_{2}}, \\ &\mathcal{L}_{T_{1}}=\mathcal{L}_{P_{1}}+\lambda \cdot \mathcal{L}_{D_{1}}, \\ &\mathcal{L}_{T_{2}}=\mathcal{L}_{P_{2}}+\lambda \cdot \mathcal{L}_{D_{2}}. \end{aligned}
LP=LP1+LP2,LT1=LP1+λ⋅LD1,LT2=LP2+λ⋅LD2.
计算完损失后,通过梯度下降法更新异构头部参数
θ
i
\theta_{i}
θi、异构特征提取器参数
φ
i
\varphi_{i}
φi 和同构适配器参数
δ
i
\delta_{i}
δi。公式如(15)所示:
θ
i
t
←
θ
i
t
−
1
−
η
∇
L
P
,
φ
i
t
←
φ
i
t
−
1
−
η
∇
L
T
1
,
δ
i
t
←
δ
i
t
−
1
−
η
∇
L
T
2
.
\begin{aligned} &\theta_{i}^{t} \leftarrow \theta_{i}^{t-1}-\eta \nabla \mathcal{L}_{P}, \\ &\varphi_{i}^{t} \leftarrow \varphi_{i}^{t-1}-\eta \nabla \mathcal{L}_{T_{1}}, \\ &\delta_{i}^{t} \leftarrow \delta_{i}^{t-1}-\eta \nabla \mathcal{L}_{T_{2}}. \end{aligned}
θit←θit−1−η∇LP,φit←φit−1−η∇LT1,δit←δit−1−η∇LT2.
在通信阶段,客户端将其同构适配器上传到服务器。服务器根据每个客户端的数据容量聚合这些同构适配器的参数。这确保了数据量更多的客户端对聚合后的同构适配器提供更可靠的更新。由于同构适配器的架构趋向于模型的较低层,它们包含更多可以在客户端之间交换的通用信息[8,25,29]。因此,同构适配器的聚合受异构数据的影响较小。聚合方法可以形式化为公式(16):
δ
=
∑
i
=
0
M
−
1
∣
D
i
∣
N
δ
i
.
\delta=\sum_{i=0}^{M-1} \frac{\left|D_{i}\right|}{N} \delta_{i}.
δ=i=0∑M−1N∣Di∣δi.
我们的最终目标是为每个客户端训练一个更好的本地模型。同构适配器在每个客户端的训练阶段仅作为辅助模块。经过多轮联邦训练,本地异构模型在同构适配器的帮助下学习并平衡全局和本地知识。因此,在模型推理阶段,我们丢弃所有客户端的同构适配器以降低计算成本。相反,我们仅使用本地异构模型(包括异构特征提取器和头部)进行推理。
Evaluations
5.1 设置
数据集:我们在三个流行的图像分类数据集上评估我们的方法和基线方法,包括CIFAR-10、CIFAR100和Tiny-ImageNet。
- CIFAR-10和CIFAR100都包含60,000张彩色图像,每张图像的尺寸为32×32像素。
- CIFAR-10包含10个不同的类别,每个类别有6000张图像。
- CIFAR-100包含100个不同的类别,每个类别有600张图像。
- Tiny-ImageNet由100,000张彩色图像组成,每张图像尺寸为64×64像素,它包含200个类别,每个类别有500张图像。
基线方法:我们将我们的方法与以下7种基线方法进行比较。
- Standalone(独立训练)使每个模型在每个客户端上独立训练,不与服务器进行任何通信。
- LG-FedAvg拆分每个本地模型,允许较低层(异构特征提取器)具有不同的架构,同时与服务器共享相同的较高层(同构头部)。
- FedGen在服务器端通过无数据方法训练一个生成器,并将集成信息广播给客户端。
- FML为每个客户端设置两个完整的模型(本地导师模型和全局学生模型),通过知识蒸馏促进模型训练。
- FedKD为每个客户端设置两个完整的模型(本地导师模型和全局学生模型),通过知识蒸馏促进模型训练。
- FedDistill共享类别平均logits,并通过知识蒸馏交换知识。
- FedProto通过本地和全局原型对本地模型的训练进行正则化。
统计异质性:我们在两种常用的场景中实现统计异质性设置。
- 第一种是病理设置 pathological setting,我们从Cifar10/Cifar100/Tiny-ImageNet中为每个客户端分配2/10/20个类别。
- 第二种是实际设置 practical setting,在实验中我们使用狄利克雷分布(表示为 D i r ( β ) Dir(\beta) Dir(β))对这三个数据集进行划分。随着 β \beta β值减小,统计异质性的强度会增加。
- 按照病理和实际设置,我们将每个客户端的数据分为训练集(75%)和测试集(25%)。此外,我们在图5中展示了病理和实际数据分布(包括训练集和测试集)的可视化结果。
模型异质性:通过调整模型宽度获得异构CNN模型组。在这个模型组中,卷积核的数量和全连接层的维度不同。
-
表2展示了五种异构模型架构,第 ( i m o d 5 ) (i \mod 5) (imod5)个架构分配给客户端 i i i。我们将前五层(即Conv1、Maxpool1、Conv2、Maxpool2和FC1)视为异构特征提取器,其中FC1的输出维度设置为1000用于蒸馏。最后两层(即FC2和FC3)被视为异构头部。FC3的输出维度与CIFAR-10、CIFAR-100或Tiny-ImageNet数据集中的类别数量一致。
-
通过调整模型深度设计了异构ResNet模型组。该模型组包括五个异构模型:ResNet-4、ResNet-6、ResNet-8、ResNet-10和ResNet-18。我们用AdaptiveAvgPool1d层替换ResNet模型的最后一个全连接层。为了模拟异构头部,我们在AdaptiveAvgPool1d层之后添加表2中的头部。AdaptiveAvgPool1d层的输出维度设置为1000。
-
对于FedGen和LG-FedAvg,它们依赖同构头部。因此,我们将它们的FC2输出维度更改为500。
-
对于FML和FedKD,它们需要在每个客户端上训练两个模型(导师模型和学生模型)。因此,我们选择CNN-4或ResNet-4作为它们的同构学生模型。
同构适配器:对于我们的FedAKT,选择合适的适配器至关重要。为了最小化计算和通信开销,同构适配器需要设计得尽可能小。我们的方法选择最小客户端模型的特征提取器作为同构适配器。因此,对于表2中的CNN模型组,我们选择CNN-4的特征提取器(前五层)作为同构适配器。同样,对于ResNet模型组,我们选择ResNet-4的特征提取器作为同构适配器。此外,当客户端有足够的资源时,可以采用更大的同构适配器。总体而言,同构适配器的选择不是固定的。
超参数设置:我们在PFLlib平台上进行实验。除非另有说明,我们使用以下实验设置。
- 我们设置100次全局通信轮次,客户端参与率 ρ = 1 \rho = 1 ρ=1。
- 遵循FedAvg,我们设置批量大小 B = 10 B = 10 B=10,学习率 η = 0.01 \eta = 0.01 η=0.01。
- 在每一轮通信中,我们在每个客户端上训练一个本地epoch。
- 对于我们的FedAKT,我们通过设置超参数 λ \lambda λ来控制同构适配器和异构特征提取器之间的知识交换程度。对于CIFAR-10和CIFAR-100数据集,我们将 λ \lambda λ设置为3。对于Tiny-ImageNet数据集,我们将 λ \lambda λ设置为0.5。此外,对于基线方法中的特定超参数,我们遵循每个方法原始论文中指定的实验设置。
- Standalone和LG-FedAvg没有额外的超参数。
- 对于FedGen,我们将其生成器学习率设置为0.1,服务器学习轮数设置为100,噪声维度设置为32。
- 对于FML,我们将知识蒸馏超参数 α = 0.5 \alpha = 0.5 α=0.5和 β = 0.5 \beta = 0.5 β=0.5。
- 对于FedKD,我们将其动态梯度近似策略阈值 T s t a r t = 0.95 T_{start} = 0.95 Tstart=0.95和 T e n d = 0.95 T_{end} = 0.95 Tend=0.95。此外,我们将学生模型学习率 η = 0.01 \eta = 0.01 η=0.01,这与FML和FedKD中的本地导师模型学习率相同。
- 对于FedDistill,我们将知识蒸馏超参数 λ = 1 \lambda = 1 λ=1。
- 对于FedProto,我们将其正则化项系数 λ = 0.1 \lambda = 0.1 λ=0.1。
5.2 结果与分析
我们在一台拥有32G内存、i9 - 13900K处理器、一块NVIDIA 4090 GPU和Ubuntu 20.04.6 LTS系统的机器上运行所有实验。我们使用每个客户端的测试集来评估其本地模型的准确性。然后,通过计算所有客户端模型的平均测试准确率来报告所有算法的最佳结果。
有效性:为了证明我们的FedAKT的有效性,我们在CIFAR-10、CIFAR-100和Tiny-ImageNet数据集上进行实验。
-
表3展示了CNN模型组的平均测试准确率,其中客户端数量为20。
-
图6展示了每一轮通信的平均测试准确率。实验结果表明,我们的方法比这些基线方法具有更好的性能。对于这三个数据集,在病理设置中,我们的方法比最佳基线方法的准确率最高可提高0.84%、1.47%和0.85%。此外,在实际设置( β = 0.1 \beta = 0.1 β=0.1)中,我们的方法比最佳基线方法的准确率最高可提高0.64%、1.32%和2.00%。
-
表4展示了在实际设置( β = 0.5 \beta = 0.5 β=0.5)下三个数据集的平均测试准确率。我们发现,在数据异质性较低的环境中,所有方法的性能都会略有下降。这是因为MHPFL方法在高度异质的环境中更注重个性化。
-
如表5所示,我们还在Tiny-ImageNet数据集上使用ResNet模型组进行实验。我们方法的平均测试准确率高于大多数基线方法。
可扩展性:为了证明我们的方法在不同客户端数量下的可扩展性,我们在CIFAR-100数据集上设置10、50、100和200个客户端,并使用CNN模型组进行实验。
-
图7展示了我们的方法与所有基线方法的平均测试准确率对比。在50个客户端的设置下,与最佳基线方法相比,我们的方法在病理设置中的准确率提高了3.42%,在实际设置中的准确率提高了1.59%。对于100个客户端,我们的方法在病理设置和实际设置中分别比最佳基线方法的准确率最高可提高3.02%和2.31%。同样,在200个客户端的情况下,与最佳基线方法相比,我们的方法在病理设置中的准确率提高了2.51%,在实际设置中的准确率提高了3.20%。由于总数据量是恒定的,客户端数量的增加会导致每个客户端分配到的数据更少。更分散的数据会给联邦训练带来更大的挑战,但与这些基线方法相比,我们算法的平均测试准确率仍然最高。
-
此外,我们将本地训练轮次更改为5,并在20个客户端上进行实验,在图8中报告平均测试准确率。
-
另外,如果客户端有足够的系统资源,可以选择更大本地模型的特征提取器作为同构适配器。因此,我们进行了进一步的实验,并在图9中报告结果。对于CNN模型组,我们为每个客户端选择CNN-5的特征提取器作为同构适配器。然后将其与原始的同构适配器(使用CNN-4的特征提取器)进行比较。同样,对于ResNet模型组,我们为每个客户端选择ResNet-6的特征提取器作为同构适配器,并与原始的同构适配器(使用ResNet-4的特征提取器)进行比较。CNN模型组在CIFAR-100数据集上进行训练,而ResNet模型组在Tiny-ImageNet数据集上进行训练。结果进一步证明了我们的方法具有出色的可扩展性。
方法分析:
- 对于Standalone,每个客户端仅使用自己的数据训练本地模型,不进行跨客户端的知识交换。
- 对于LG-FedAvg,高层模型(同构头部)的聚合很容易受到异构数据的影响。
- 对于FedGen,服务器上的轻量级生成器可能难以生成对客户端有用的信息,特别是在高度异构的环境中。
- 对于FML,软预测的知识对于相互学习来说是有限的。
- 对于FedKD,共享压缩后的学生模型可能会导致一些有价值的信息丢失。
- 对于FedDistill和FedProto,可共享的类别平均信息可能不足。
- 然而,我们的方法避免了这些问题,并实现了更高的平均测试准确率。
计算成本:
- 在表6中,我们通过计算在Cifar100上100轮通信的总训练时间(使用CNN模型组)来报告计算成本。结果表明,我们的方法优于FedGen、FML和FedKD,因为这些方法依赖于训练辅助生成器或学生模型。然而,我们的FedAKT的计算成本高于LG-FedAvg、FedDistill和FedProto。因为它们不需要训练额外的组件,而我们的方法需要训练同构适配器,这会带来额外的计算负担。此外,我们在推理时仅使用没有同构适配器的本地模型。模型推理阶段的计算成本与所有基线方法相同。总体而言,为了换取模型性能而产生少量额外的计算成本是可以接受的。
通信成本:
- 在表6中,我们还通过分析理论和实际参数数量,报告了在CIFAR-100上一次迭代的通信成本(使用CNN模型组)。 θ i \theta_{i} θi表示LG-FedAvg和FedGen中同构头部的参数。 S g S_{g} Sg表示FedGen中生成器的参数。 φ g \varphi_{g} φg和 θ g \theta_{g} θg分别表示FedKD和FML中同构特征提取器和头部的参数。 r r r是FedKD中的模型压缩率。 c c c表示所有客户端上的类别数量。 C i C_{i} Ci是客户端 i i i上的类别数量。 K K K是隐藏特征的维度。“ M M M”表示百万。与这些基线方法相比,我们的方法取得了中等的结果。FedGen、FML和FedKD的通信成本高于我们的方法。具体来说,在FML和FedKD中,传输学生模型导致了最高的通信成本。在FedGen中,从服务器发送生成器会增加额外的通信成本。然而,我们的方法比FedDistill和FedProto的通信成本更高。这是因为我们的FedAKT在客户端和服务器之间交换适配器,而这两个基线方法仅交换轻量级的类别平均信息。
隐私保护:我们的方法将本地数据和异构模型保留在每个客户端上。仅在服务器和客户端之间交换小型同构适配器,通过这些小型同构适配器很难推断出本地数据。因此,本地数据和模型架构的隐私得到了保护。
5.3 消融研究
为了研究每个机制对我们的FedAKT性能的影响,如图10所示,我们在CIFAR-10和CIFAR-100数据集(使用CNN模型组)上进行了以下四种情况的消融实验:
- 保留我们的FedAKT的完整方法。
- 去除头部两用机制(HDU),保留基于特征的互蒸馏机制(FMD)。
- 去除基于特征的互蒸馏机制(FMD),保留头部两用机制(HDU)。
- 去除同构适配器(算法退化为独立训练,客户端仅训练其本地异构模型)。
关于消融实验,我们将客户端数量设置为20,数据分布设置为病理设置。消融实验的平均测试准确率如表7所示。
- 将情况(1)与情况(4)进行比较,准确率分别提高了1.24%和5.04%。
- 将情况(2)与情况(4)进行比较,有基于特征的互蒸馏机制的准确率比没有该机制时分别提高了0.8%和3.28%。
- 将情况(3)与情况(4)进行比较,有头部共享机制的准确率比没有该机制时分别提高了0.28%和0.51%。同时去除FMD和HDU导致的平均测试准确率比仅去除其中一个机制更低,这意味着这两个机制可以相互增强。
5.4 讨论
我们通过在多个数据集上进行大量实验,与基线方法进行了对比分析。同时,我们通过进一步的消融研究证明了每个模块的有效性。结果表明,FedAKT提高了本地模型的准确率,在各种统计和模型异质性设置中表现良好。它具有灵活性和可扩展性,能够适应不同数量的客户端参与联邦训练。FedAKT中的同构适配器作为跨客户端知识转移和本地模型知识交换的桥梁。然而,仍然存在局限性。具体来说,由于引入了同构适配器,与最佳基线方法相比,我们的方法增加了计算和通信成本。此外,我们的工作假设客户端拥有异构模型架构并执行基本的图像分类任务。在实际应用中,我们需要进一步考虑为不同客户端定制何种模型架构。而且,如何解决多样化的任务以满足客户端的个性化需求也是一个挑战。
Conclusion
在这项工作中,我们提出了一种名为FedAKT的新型模型异构个性化联邦学习(MHPFL)方法,旨在应对统计异质性和模型异质性带来的挑战。我们的方法为每个客户端添加了一个小型的同构适配器,这些同构适配器会上传至服务器进行聚合,以获取全局知识,而异构特征提取器和头部则保留在本地,以保存本地知识。我们引入了基于特征的互蒸馏机制,以实现同构适配器与异构特征提取器之间的有效知识交换。此外,我们采用了头部共享机制,以捕获更多的特征信息。我们的方法在不依赖公共数据集或复杂生成模型的情况下,有效地提升了本地模型的性能。大量实验表明,我们的FedAKT在三个图像分类数据集上取得了更高的准确率。
在未来的工作中,我们将从模型准确性、理论分析、算法成本和实际应用这四个方面来改进我们的方法。
- 同构适配器在我们的FedAKT中起着关键作用。因此,设计出更具通用性和高效性的适配器能够进一步提高本地模型的准确性。
- 我们已经通过大量实验进行了对比分析,但完整的理论分析并非易事,这将留待未来的工作去完成。
- 联邦学习广泛应用于资源受限的场景中。所以,探索进一步提高通信效率并降低计算成本的方法是很有意义的。
- 为特定的客户端和任务部署合理的模型架构,这也是我们未来需要研究的工作。