一. 介绍
随着数据的爆炸式增长和严格的隐私保护政策,由于高昂的带宽成本和隐私泄露的风险,不计后果的数据传输和聚合逐渐变得不可接受。最近,联邦学习(FL)被提出来取代传统的高度集中的学习范式,并保护数据隐私。FL的主要挑战之一是数据的异构性,即客户端的数据是不相同的和独立分布的(Non-IID)。经验证,在这种情况下,简单的FL算法FedAvg会导致局部模型漂移,灾难性地忘记全局知识,从而进一步导致性能下降和收敛缓慢。这是因为局部模型仅用局部数据更新,即最小化局部经验损失。然而,在Non-IID的FL中,最小化局部经验损失与最小化全局经验损失从根本上是不一致的。
本文提出了一种新的方法,通过无数据知识蒸馏(FedFTG)动态微调全局模型来提高标准FL的性能,该方法同时改进了模型聚合过程并利用了服务器的丰富计算能力。具体来说,FedFTG通过服务器中的辅助生成器对局部模型的输入空间进行建模,然后生成伪数据,将局部模型中的知识转移到全局模型中,从而提高性能。为了在整个训练过程中进行有效的知识蒸馏,FedFTG迭代探索数据分布中的硬样本,这将导致局部模型和全局模型之间的预测不一致。下图是FedFTG和FedAVG的比较。FedFTG使用硬样本对全局模型进行微调,以纠正模型聚合后的模型偏移。生成器和全局模型以无数据的方式进行对抗训练,因此整个过程不会违反FL中的隐私策略。考虑到数据异构场景下的标签分布转移,我们进一步提出定制标签采样和类级集成技术,探索客户端的分布相关性,最大限度地利用知识。
二. 方法
2.1 基于硬样本挖掘的无数据知识蒸馏全局模型优化
我们使用
w
w
w表示为服务器和客户端的模型参数。在本文中,我们考虑存在
K
K
K个客户端,其中
D
k
=
{
(
x
k
,
i
,
y
k
,
i
)
}
i
=
1
N
k
\mathcal{D}_k=\left\{\left(x_{k, i}, y_{k, i}\right)\right\}_{i=1}^{N_k}
Dk={(xk,i,yk,i)}i=1Nk为存储在第
k
k
k个客户端上独立的数据集,
N
k
N_k
Nk表示为样本的数量。一般来说,联邦学习需要解决下面的问题:
min
ω
1
K
∑
k
=
1
K
f
k
(
ω
)
,
f
k
(
ω
)
=
1
N
k
∑
i
=
1
N
k
L
(
x
k
i
,
y
k
i
;
ω
)
(1)
\min _\omega \frac{1}{K} \sum_{k=1}^K f_k(\omega), f_k(\omega)=\frac{1}{N_k} \sum_{i=1}^{N_k} \mathcal{L}\left(x_k^i, y_k^i ; \omega\right) \tag{1}
ωminK1k=1∑Kfk(ω),fk(ω)=Nk1i=1∑NkL(xki,yki;ω)(1)
其中
L
\mathcal{L}
L表示为损失函数,每个数据集
D
k
D_k
Dk都是分布不均匀的。由于FL中的隐私保护约束,服务器不能直接访问客户端的本地数据。为了解决式(1),在每个通信轮
t
t
t中,现有的方法讲全局模型
w
w
w发送到随机的客户端集合
S
t
S_t
St,并且去优化
min
w
f
k
(
w
)
,
k
∈
S
t
\min_w f_k(w),k\in S_t
minwfk(w),k∈St。服务器搜集这些本地的模型
{
w
k
}
k
∈
S
t
\{w_k\}_{k\in S_t }
{wk}k∈St,并且聚合梯度去更新全局的模型
w
w
w。然而,在数据异构的场景下,局部模型之间存在着很大的偏差。这就导致聚合后的模型效果反而低于聚合前的模型。为了解决这个问题,我们提出了一种无数据的知识蒸馏方法来对全局模型进行微调,使全局模型能够保留局部模型中的知识并尽可能地保持它们的性能。具体来说,服务器维护一个条件生成器
G
G
G,它生成伪数据以捕获客户端的数据分布,如下所示:
x
~
=
G
(
z
,
y
;
θ
)
(2)
\widetilde{x}=G(z, y ; \theta) \tag{2}
x
=G(z,y;θ)(2)
其中
θ
\theta
θ表示
G
G
G的参数,
z
∼
N
(
0
,
1
)
z \sim \mathcal{N}(\mathbf{0}, \mathbf{1})
z∼N(0,1)为一个标准的Gaussian 噪音,
y
y
y表示为从已知分布
p
t
(
y
)
p_t(y)
pt(y)中进行采样的
x
~
\widetilde{x}
x
对应的类标签。
如上图展示的,我们将伪数据输入到全局模型中去解决下面的问题:
min
ω
E
z
∼
N
(
0
,
1
)
y
∼
p
t
(
y
)
[
L
m
d
]
=
min
ω
E
z
∼
N
(
0
,
1
)
y
∼
p
t
(
y
)
[
∑
k
∈
S
t
α
t
k
,
y
L
m
d
k
]
\min _\omega \mathbb{E}_{\substack{z \sim \mathcal{N}(0,1) \\ y \sim p_t(y)}}\left[\mathcal{L}_{m d}\right]=\min _\omega \mathbb{E}_{\substack{z \sim \mathcal{N}(0,1) \\ y \sim p_t(y)}}\left[\sum_{k \in S_t} \alpha_t^{k, y} \mathcal{L}_{m d}^k\right]
ωminEz∼N(0,1)y∼pt(y)[Lmd]=ωminEz∼N(0,1)y∼pt(y)[k∈St∑αtk,yLmdk]
其中
L
m
d
k
\mathcal{L}_{m d}^k
Lmdk伪全局模型和本地模型的模型差异度函数:
L
m
d
k
=
D
K
L
(
σ
(
D
(
x
~
;
ω
)
)
∥
σ
(
D
(
x
~
;
ω
k
)
)
)
,
(3)
\mathcal{L}_{m d}^k=D_{K L}\left(\sigma(D(\widetilde{x} ; \omega)) \| \sigma\left(D\left(\widetilde{x} ; \omega_k\right)\right)\right), \tag{3}
Lmdk=DKL(σ(D(x
;ω))∥σ(D(x
;ωk))),(3)
其中
D
D
D表示为一个分类器,
σ
\sigma
σ为一个softmax函数,
D
K
L
D_{KL}
DKL表示为KL散度,
α
t
k
,
y
\alpha^{k,y}_t
αtk,y表示为集成过程中控制来自不同局部模型的知识权重。
数据的真实度和多样性约束: 为了更好地从局部模型中提取知识,伪数据需要和本地模型的输入保持一致。因此,我们使用语义损失
L
c
l
s
\mathcal{L}_{cls}
Lcls去训练
G
G
G,如下:
min
θ
E
z
∼
N
(
0
,
1
)
y
∼
p
t
(
y
)
[
L
c
l
s
]
=
min
θ
E
z
∼
N
(
0
,
1
)
y
∼
p
t
(
y
)
[
∑
k
∈
S
t
α
t
k
,
y
L
c
l
s
k
]
(5)
\min _\theta \mathbb{E}_{\substack{z \sim \mathcal{N}(\mathbf{0}, 1) \\ y \sim p_t(y)}}\left[\mathcal{L}_{c l s}\right]=\min _\theta \mathbb{E}_{\substack{z \sim \mathcal{N}(\mathbf{0}, 1) \\ y \sim p_t(y)}}\left[\sum_{k \in S_t} \alpha_t^{k, y} \mathcal{L}_{c l s}^k\right] \tag{5}
θminEz∼N(0,1)y∼pt(y)[Lcls]=θminEz∼N(0,1)y∼pt(y)[k∈St∑αtk,yLclsk](5)
其中
L
c
l
s
k
\mathcal{L}_{cls}^k
Lclsk为交叉熵损失:
L
c
l
s
k
=
L
C
E
(
σ
(
D
(
x
~
;
ω
k
)
)
,
y
)
,
(6)
\mathcal{L}_{c l s}^k=\mathcal{L}_{C E}\left(\sigma\left(D\left(\tilde{x} ; \omega_k\right)\right), y\right), \tag{6}
Lclsk=LCE(σ(D(x~;ωk)),y),(6)
这样就能保证
x
~
\widetilde{x}
x
能够达到一个高的预测。
简单地使用
L
c
l
s
\mathcal{L}_{cls}
Lcls将会导致模型崩塌:
G
G
G将会输出相同的数据对于每个类来说。为了解决这个问题,我们中使用多样性损失
L
d
i
s
\mathcal{L}_{dis}
Ldis来提高生成数据的多样性,
L
dis
=
e
1
Q
∗
Q
∑
i
,
j
∈
{
1
,
…
,
Q
}
(
−
∥
x
~
i
−
x
~
j
∥
2
∗
∥
z
i
−
z
j
∥
2
)
(7)
\mathcal{L}_{\text {dis }}=e^{\frac{1}{Q*Q} \sum_{i, j \in\{1, \ldots, Q\}}\left(-\left\|\widetilde{x}_i-\widetilde{x}_j\right\|_2 *\left\|z_i-z_j\right\|_2\right)} \tag{7}
Ldis =eQ∗Q1∑i,j∈{1,…,Q}(−∥x
i−x
j∥2∗∥zi−zj∥2)(7)
硬样本挖掘: 使用
L
c
l
s
\mathcal{L}_{cls}
Lcls训练生成器
G
G
G将生成具有低分类错误的伪数据
x
~
\widetilde{x}
x
,这意味着
x
~
\widetilde{x}
x
包含类
y
y
y中最有区别的特征,并且易于分类。然而,这些原始样本不会导致全局模型和局部模型之间的预测不一致,即
L
m
d
=
0
\mathcal{L}_{md}=0
Lmd=0,因此在训练期间不会优化全局模型。为了有效利用局部模型中的知识并将其传递到全局模型,我们探索了数据分布中导致局部模型和全局模型之间预测不一致的硬样本。具体地说,我们用
L
m
d
\mathcal{L}_{md}
Lmd对抗性地训练生成器和全局模型:(1)生成器被强制生成最大化
L
m
d
\mathcal{L}_{md}
Lmd的硬样本,(2)全局模型被训练为使用硬样本最小化
L
m
d
\mathcal{L}_{md}
Lmd。因此,可以逐步微调全局模型,以适应图3所示的数据分布:
min
ω
max
θ
E
z
∼
N
(
0
,
1
)
,
y
∼
p
t
(
y
)
[
L
m
d
−
λ
cls
L
cls
−
λ
dis
L
dis
]
(8)
\min _\omega \max _\theta \mathbb{E}_{z \sim \mathcal{N}(0,1), y \sim p_t(y)}\left[\mathcal{L}_{m d}-\lambda_{\text {cls }} \mathcal{L}_{\text {cls }}-\lambda_{\text {dis }} \mathcal{L}_{\text {dis }}\right] \tag{8}
ωminθmaxEz∼N(0,1),y∼pt(y)[Lmd−λcls Lcls −λdis Ldis ](8)
2.2 适应标签分布转移的有效知识蒸馏
在数据异构的场景下,数据之间的类是不平衡的,并且对于一个类别,知识的重要性在不同的本地模型中也是不同的。
自定义标签样本: 通常,在数据异构场景中,本地客户端的数据集是类不平衡的,甚至有些类没有数据。已经证明深度神经网络倾向于学习多数类而忽略少数类。因此,局部模型中少数类的数据信息可能是错误的,具有误导性,生成的伪数据对于度量模型差异无效。如果对类标签
y
y
y进行统一采样,这些无效数据会影响全局模型的训练,导致性能下降。为了缓解这一问题,我们根据每一轮训练数据的整体分布自定义采样概率
p
t
(
y
)
p_t(y)
pt(y),从而生成更多具有有效信息的伪数据,
p
t
(
y
)
∝
∑
k
∈
S
t
∑
i
=
1
N
k
E
(
x
k
i
,
y
k
i
)
∼
D
k
[
1
y
i
=
y
]
=
∑
k
∈
S
t
n
k
y
,
(9)
p_t(y) \propto \sum_{k \in S_t} \sum_{i=1}^{N_k} \mathbb{E}_{\left(x_k^i, y_k^i\right) \sim \mathcal{D}_k}\left[1_{y_i=y}\right]=\sum_{k \in S_t} n_k^y, \tag{9}
pt(y)∝k∈St∑i=1∑NkE(xki,yki)∼Dk[1yi=y]=k∈St∑nky,(9)
其中
1
c
o
n
d
i
t
i
o
n
1_{condition}
1condition表示如果条件成立那么为1否则为0,
n
k
y
n^y_k
nky表示为类
y
y
y的数量在客户端
k
k
k中。由式(9)可知,大多数类的伪数据生成概率较高,因此FedFTG可以保证在数据异构场景下有效的知识蒸馏。
类集成: 由于标签分布的变化,对于一个类别,知识的重要性在本地模型中是不同的。如果把同样的权重分配给客户,就不能很好地理解和利用重要的知识。因此,我们提出类级集成,根据客户端个体数据分布分配集成权重。因此,我们计算
α
t
k
,
y
\alpha^{k,y}_t
αtk,y根据类
y
y
y在客户端
k
k
k中以及全局中的比值进行计算:
α
t
k
,
y
=
n
k
y
/
∑
i
∈
S
t
n
i
y
\alpha_t^{k, y}=n_k^y / \sum_{i \in S_t} n_i^y
αtk,y=nky/i∈St∑niy
因此,可以根据客户端的知识在类中的重要性,灵活地将来自客户端的知识进行集成,使FedFTG能够最大限度地利用来自本地模型的知识。下面为算法:
三. 总结
本文其实是利用Data-Free的知识蒸馏到联邦学习中,Data-Free旨在利用数据生成的方式(GAN)去不断地训练学生网络,而本文则是将微调的方式去处理聚合的过程。利用生成的样本不断地训练全局模型,同时再利用设定的各种损失去训练样本生成器,达到一种对抗的程度,从而获得更好的效果。代码目前还没找到,找到了我会更新上来