一. 介绍
联邦学习具有广阔的应用前景,但面临着来自数据异构的挑战,因为在现实世界中用户数据均为Non-IID分布的。在这样的情况下,传统的联邦学习算法可能会导致无法收敛到各个客户端的数据。
在本文中,我们提出了一个基于无数据的知识蒸馏算法——FEDGEN。具体来说,FEDGEN学习一个仅从用户模型的预测规则派生的生成模型,给定一个目标标签,可以生成与用户预测集合一致的特征表示。该生成器随后广播给用户,在潜在空间上护送他们的模型训练与增强样本,这体现了来自其他同行用户的蒸馏知识。给定一个维数远远小于输入空间的潜在空间,FEDGEN学习到的生成器可以是轻量级的,为当前FL框架引入最小的开销。
二. 问题定义
我们使用
X
⊂
R
p
\mathcal{X} \subset \mathbb{R}^p
X⊂Rp 表示为输入的实例空间
Z
⊂
R
d
\mathcal{Z} \subset \mathbb{R}^d
Z⊂Rd 表示为潜在的特征空间,其中
d
<
p
d<p
d<p,
Y
⊂
R
\mathcal{Y} \subset \mathbb{R}
Y⊂R为输出空间。
T
\mathcal{T}
T 表示为一个具体的域(domain),其由数据样本
X
\mathcal{X}
X 组成的数据分布
D
\mathcal{D}
D 和一个真实标签的函数:
c
∗
:
X
→
Y
c^*: \mathcal{X} \rightarrow \mathcal{Y}
c∗:X→Y组成,也就是
T
:
=
⟨
D
,
c
∗
⟩
\mathcal{T}:=\left\langle\mathcal{D}, c^*\right\rangle
T:=⟨D,c∗⟩。注意,在本文中将任务和域当作一样对待。模型参数
θ
:
=
[
θ
f
;
θ
p
]
\boldsymbol{\theta}:=\left[\boldsymbol{\theta}^f ; \boldsymbol{\theta}^p\right]
θ:=[θf;θp]由两个部分组成:一个特征提取器
f
:
X
→
Z
f: \mathcal{X} \rightarrow \mathcal{Z}
f:X→Z(对应参数为
θ
f
\boldsymbol{\theta}^f
θf), 一个预测器
h
:
Z
→
△
Y
h: \mathcal{Z} \rightarrow \triangle^{\mathcal{Y}}
h:Z→△Y (由参数
θ
p
\boldsymbol{\theta}^p
θp组成),其中
Δ
Y
\Delta^{\mathcal{Y}}
ΔY 表示单独的一个
Y
\mathcal{Y}
Y。给定一个凸的损失函数
l
:
△
Y
×
Y
→
R
l: \triangle^{\mathcal{Y}} \times \mathcal{Y} \rightarrow \mathbb{R}
l:△Y×Y→R, 模型参数
θ
\boldsymbol{\theta}
θ 在任务
T
\mathcal{T}
T上的损失表示为
L
T
(
θ
)
:
=
\mathcal{L}_{\mathcal{T}}(\boldsymbol{\theta}):=
LT(θ):=
E
x
∼
D
[
l
(
h
(
f
(
x
;
θ
f
)
;
θ
p
)
,
c
∗
(
x
)
)
]
\mathbb{E}_{x \sim \mathcal{D}}\left[l\left(h\left(f\left(x ; \boldsymbol{\theta}^f\right) ; \boldsymbol{\theta}^p\right), c^*(x)\right)\right]
Ex∼D[l(h(f(x;θf);θp),c∗(x))]。
联邦学习 致力于学习一个全局的模型参数
θ
\theta
θ,其能在每个客户端上都能达到最小的损失:
min
θ
E
T
k
∈
T
[
L
k
(
θ
)
]
(1)
\min _{\boldsymbol{\theta}} \mathbb{E}_{\mathcal{T}_k \in \mathcal{T}}\left[\mathcal{L}_k(\boldsymbol{\theta})\right] \tag1
θminETk∈T[Lk(θ)](1)
其中
T
=
{
T
k
}
k
=
1
K
\mathcal{T}=\left\{\mathcal{T}_k\right\}_{k=1}^K
T={Tk}k=1K表示所有的客户端的任务。 我们假设所有的任务共享相同的标签规则
c
∗
c^*
c∗以及损失函数
T
k
=
⟨
D
k
,
c
∗
⟩
\mathcal{T}_k=\left\langle\mathcal{D}_k, c^*\right\rangle
Tk=⟨Dk,c∗⟩。在实际中, 等式1可以进行这样的优化:
min
θ
1
K
∑
k
=
1
K
L
^
k
(
θ
)
\min _\theta \frac{1}{K} \sum_{k=1}^K \hat{\mathcal{L}}_k(\boldsymbol{\theta})
minθK1∑k=1KL^k(θ), 其中
L
^
k
(
θ
)
:
=
1
∣
D
^
k
∣
∑
x
i
∈
D
^
k
[
l
(
h
(
f
(
x
i
;
θ
f
)
;
θ
p
)
,
c
∗
(
x
i
)
)
]
\hat{\mathcal{L}}_k(\boldsymbol{\theta}):=\frac{1}{\left|\hat{\mathcal{D}}_k\right|} \sum_{x_i \in \hat{\mathcal{D}}_k}\left[l\left(h\left(f\left(x_i ; \boldsymbol{\theta}^f\right) ; \boldsymbol{\theta}^p\right), c^*\left(x_i\right)\right)\right]
L^k(θ):=∣D^k∣1∑xi∈D^k[l(h(f(xi;θf);θp),c∗(xi))]表示在数据集
D
^
k
\hat{\mathcal{D}}_k
D^k上的经验损失。 这里有个隐含的假设是:对于全局数据
D
^
\hat{\mathcal{D}}
D^是分布在所有的客户端上:
D
^
=
∪
{
D
^
k
}
k
=
1
K
\hat{\mathcal{D}}=\cup\left\{\hat{\mathcal{D}}_k\right\}_{k=1}^K
D^=∪{D^k}k=1K。
知识蒸馏(KD) 也被称为师生范式,其目标是学习轻量级的学生模型,使用从一个或多个强大的老师那里提取的知识。典型的KD利用一个代理数据集“pd”来最小化分别来自教师模型
θ
T
\theta^T
θT和学生模型
θ
S
\theta^S
θS的logits输出之间的差异。一个代表性的选择是使用Kullback-Leibler散度来衡量这种差异:
min
θ
S
E
x
∼
D
^
P
[
D
K
L
[
σ
(
g
(
f
(
x
;
θ
T
f
)
;
θ
T
p
)
∥
σ
(
g
(
f
(
x
;
θ
S
f
)
;
θ
S
p
)
]
]
\min _{\boldsymbol{\theta}_S} \mathbb{E}_{x \sim \hat{\mathcal{D}}_{\mathrm{P}}}\left[D _ { \mathrm { KL } } \left[\sigma\left(g\left(f\left(x ; \boldsymbol{\theta}_T^f\right) ; \boldsymbol{\theta}_T^p\right) \| \sigma\left(g\left(f\left(x ; \boldsymbol{\theta}_S^f\right) ; \boldsymbol{\theta}_S^p\right)\right]\right]\right.\right.
θSminEx∼D^P[DKL[σ(g(f(x;θTf);θTp)∥σ(g(f(x;θSf);θSp)]]
其中
g
(
⋅
)
g(\cdot)
g(⋅)表示为模型的逻辑输出,
σ
(
⋅
)
\sigma(\cdot)
σ(⋅)表示为激活函数。
KD的想法已经扩展到FL,以解决用户异质性,通过将每个用户模型
θ
k
\theta_k
θk视为教师,其信息聚合到学生(全局)模型
θ
\theta
θ,以提高其泛化性能:
min
θ
E
x
∼
D
^
P
[
D
K
L
[
σ
(
1
K
∑
k
=
1
K
g
(
f
(
x
;
θ
k
f
)
;
θ
k
p
)
)
∥
σ
(
g
(
f
(
x
;
θ
f
)
;
θ
p
)
]
]
\min _{\boldsymbol{\theta}} \mathbb{E}_{x \sim \hat{\mathcal{D}}_{\mathrm{P}}}\left[D_{\mathrm{KL}}\left[\sigma\left(\frac{1}{K} \sum_{k=1}^K g\left(f\left(x ; \boldsymbol{\theta}_k^f\right) ; \boldsymbol{\theta}_k^p\right)\right) \| \sigma\left(g\left(f\left(x ; \boldsymbol{\theta}^f\right) ; \boldsymbol{\theta}^p\right)\right]\right]\right.
θminEx∼D^P[DKL[σ(K1k=1∑Kg(f(x;θkf);θkp))∥σ(g(f(x;θf);θp)]]
上述方法的一个主要限制在于它依赖于代理数据集“
D
^
P
\hat{D}_P
D^P”,需要仔细考虑它的选择,并在蒸馏性能中起着关键作用。接下来,我们将展示如何以无数据的方式使KD对FL可行。
三. FEDGEN: 通过生成学习实现无数据的联邦蒸馏
本方法如图所示:
3.1 知识提取
我们的核心思想是提取关于数据分布的全局视图的知识,这些知识是传统FL无法观察到的,并将这些知识提取到局部模型中,以指导它们的学习。我们首先考虑学习一个条件分布
Q
∗
:
Y
→
X
Q^*: \mathcal{Y} \rightarrow \mathcal{X}
Q∗:Y→X来描述这种知识,它与真实数据分布一致:
Q
∗
=
arg
max
Q
:
Y
→
X
E
y
∼
p
(
y
)
E
x
∼
Q
(
x
∣
y
)
[
log
p
(
y
∣
x
)
]
,
(2)
Q^*=\underset{Q: \mathcal{Y} \rightarrow \mathcal{X}}{\arg \max } \mathbb{E}_{y \sim p(y)} \mathbb{E}_{x \sim Q(x \mid y)}[\log p(y \mid x)], \tag2
Q∗=Q:Y→XargmaxEy∼p(y)Ex∼Q(x∣y)[logp(y∣x)],(2)
其中
p
(
y
)
p(y)
p(y)表示标签的先验概率而
p
(
y
∣
x
)
p(y|x)
p(y∣x)表示为后验概率。为了能优化式2,我们替换了
p
(
y
)
p(y)
p(y)以及
p
(
y
∣
x
)
p(y|x)
p(y∣x)。首先,我们估计
p
(
y
)
p(y)
p(y)为:
p
^
(
y
)
∝
∑
k
E
x
∼
D
^
k
[
I
(
c
∗
(
x
)
=
y
)
]
,
\hat{p}(y) \propto \sum_k \mathbb{E}_{x \sim \hat{\mathcal{D}}_k}\left[\mathrm{I}\left(c^*(x)=y\right)\right],
p^(y)∝k∑Ex∼D^k[I(c∗(x)=y)],
其中
I
(
⋅
)
\mathrm{I}(\cdot)
I(⋅)为一个指标函数。在实际中,
p
^
(
y
)
\hat{p}(y)
p^(y)可以使用各个客户端训练标签的数量来进行统计。下一步,我们使用各个客户端的集成知识估计
p
(
y
∣
x
)
p(y|x)
p(y∣x):
log
p
^
(
y
∣
x
)
∝
1
K
∑
k
=
1
K
log
p
(
y
∣
x
;
θ
k
)
\log \hat{p}(y \mid x) \propto \frac{1}{K} \sum_{k=1}^K \log p\left(y \mid x ; \boldsymbol{\theta}_k\right)
logp^(y∣x)∝K1k=1∑Klogp(y∣x;θk)
有了上面的近似之后,直接在输入空间
X
\mathcal{X}
X上优化式子(2)依然是不行的,因为当
X
\mathcal{X}
X为高纬的时候会带来计算过载,还可能泄漏用户数据的信息。一个更好的方式是使用
G
∗
:
Y
→
Z
G^*: \mathcal{Y} \rightarrow \mathcal{Z}
G∗:Y→Z去作用于潜在的特征信息,从而避免相关的隐私暴露:
G
∗
=
arg
max
G
:
Y
→
Z
E
y
∼
p
^
(
y
)
E
z
∼
G
(
z
∣
y
)
[
∑
k
=
1
K
log
p
(
y
∣
z
;
θ
k
p
)
]
(3)
G^*=\underset{G: \mathcal{Y} \rightarrow \mathcal{Z}}{\arg \max } \mathbb{E}_{y \sim \hat{p}(y)} \mathbb{E}_{z \sim G(z \mid y)}\left[\sum_{k=1}^K \log p\left(y \mid z ; \boldsymbol{\theta}_k^p\right)\right] \tag3
G∗=G:Y→ZargmaxEy∼p^(y)Ez∼G(z∣y)[k=1∑Klogp(y∣z;θkp)](3)
根据上述推理,我们的目标是通过学习条件生成器
G
G
G进行知识提取,条件生成器
G
G
G由
w
w
w参数化,以优化以下目标:
min
w
J
(
w
)
:
=
E
y
∼
p
^
(
y
)
E
z
∼
G
w
(
z
∣
y
)
[
l
(
σ
(
1
K
∑
k
=
1
K
g
(
z
;
θ
k
p
)
)
,
y
)
]
(4)
\min _{\boldsymbol{w}} J(\boldsymbol{w}):=\mathbb{E}_{y \sim \hat{p}(y)} \mathbb{E}_{z \sim G_{\boldsymbol{w}}(z \mid y)}\left[l\left(\sigma\left(\frac{1}{K} \sum_{k=1}^K g\left(z ; \boldsymbol{\theta}_k^p\right)\right), y\right)\right] \tag4
wminJ(w):=Ey∼p^(y)Ez∼Gw(z∣y)[l(σ(K1k=1∑Kg(z;θkp)),y)](4)
其中
g
g
g和
σ
\sigma
σ表示为逻辑输出以及激活函数。这样,给定一系列样本标签,我们只需要使用用户的预测层的参数。具体来说,为了使样本更加多样化,我们创建了噪音向量:
ϵ
∼
N
(
0
,
I
)
\epsilon \sim \mathcal{N}(0, I)
ϵ∼N(0,I)。
3.2 知识提取
在特征提取之后,我们将学习到的生成器
G
w
G_w
Gw广播给本地用户,以便每个用户模型都可以从
G
w
G_w
Gw中采样,获得特征空间上的增强表示
z
∼
G
w
(
⋅
∣
y
)
z \sim G_w(\cdot \mid y)
z∼Gw(⋅∣y)。因此,局部模型
θ
k
\theta_k
θk的目标被改变,以使它对扩增样本产生理想预测的概率最大化:
min
θ
k
J
(
θ
k
)
:
=
L
^
k
(
θ
k
)
+
E
^
y
∼
p
^
(
y
)
,
z
∼
G
w
(
z
∣
y
)
[
l
(
h
(
z
;
θ
k
p
)
;
y
)
]
,
(5)
\min _{\boldsymbol{\theta}_k} J\left(\boldsymbol{\theta}_k\right):=\hat{\mathcal{L}}_k\left(\boldsymbol{\theta}_k\right)+\hat{\mathbb{E}}_{y \sim \hat{p}(y), z \sim G_{\boldsymbol{w}}(z \mid y)}\left[l\left(h\left(z ; \boldsymbol{\theta}_k^p\right) ; y\right)\right], \tag5
θkminJ(θk):=L^k(θk)+E^y∼p^(y),z∼Gw(z∣y)[l(h(z;θkp);y)],(5)
其中
L
^
k
(
θ
k
)
:
=
1
∣
D
^
k
∣
∑
x
i
∈
D
^
k
[
l
(
h
(
f
(
x
i
;
θ
k
f
)
;
θ
k
p
)
,
c
∗
(
x
i
)
)
]
\hat{\mathcal{L}}_k\left(\boldsymbol{\theta}_k\right):=\frac{1}{\left|\hat{\mathcal{D}}_k\right|} \sum_{x_i \in \hat{\mathcal{D}}_k}\left[l\left(h\left(f\left(x_i ; \boldsymbol{\theta}_k^f\right) ; \boldsymbol{\theta}_k^p\right), c^*\left(x_i\right)\right)\right]
L^k(θk):=∣D^k∣1∑xi∈D^k[l(h(f(xi;θkf);θkp),c∗(xi))]表示本地数据集上的相关损失。
相关算法如下:
四. 代码解析
代码链接点这里
我们首先来看一个客户端的训练:
def train(self, glob_iter, personalized=False, early_stop=100, regularization=True, verbose=False):
self.clean_up_counts()
self.model.train()
self.generative_model.eval()
TEACHER_LOSS, DIST_LOSS, LATENT_LOSS = 0, 0, 0
for epoch in range(self.local_epochs):
self.model.train()
for i in range(self.K):
self.optimizer.zero_grad()
#### sample from real dataset (un-weighted)
samples =self.get_next_train_batch(count_labels=True)
X, y = samples['X'], samples['y']
self.update_label_counts(samples['labels'], samples['counts'])
model_result=self.model(X, logit=True)
user_output_logp = model_result['output']
predictive_loss=self.loss(user_output_logp, y)
#### sample y and generate z
if regularization and epoch < early_stop:
generative_alpha=self.exp_lr_scheduler(glob_iter, decay=0.98, init_lr=self.generative_alpha)
generative_beta=self.exp_lr_scheduler(glob_iter, decay=0.98, init_lr=self.generative_beta)
### get generator output(latent representation) of the same label
gen_output=self.generative_model(y, latent_layer_idx=self.latent_layer_idx)['output']
logit_given_gen=self.model(gen_output, start_layer_idx=self.latent_layer_idx, logit=True)['logit']
target_p=F.softmax(logit_given_gen, dim=1).clone().detach()
user_latent_loss= generative_beta * self.ensemble_loss(user_output_logp, target_p)
sampled_y=np.random.choice(self.available_labels, self.gen_batch_size)
sampled_y=torch.tensor(sampled_y)
gen_result=self.generative_model(sampled_y, latent_layer_idx=self.latent_layer_idx)
gen_output=gen_result['output'] # latent representation when latent = True, x otherwise
user_output_logp =self.model(gen_output, start_layer_idx=self.latent_layer_idx)['output']
teacher_loss = generative_alpha * torch.mean(
self.generative_model.crossentropy_loss(user_output_logp, sampled_y)
)
# this is to further balance oversampled down-sampled synthetic data
gen_ratio = self.gen_batch_size / self.batch_size
loss=predictive_loss + gen_ratio * teacher_loss + user_latent_loss
TEACHER_LOSS+=teacher_loss
LATENT_LOSS+=user_latent_loss
else:
#### get loss and perform optimization
loss=predictive_loss
loss.backward()
self.optimizer.step()#self.local_model)
# local-model <=== self.model
self.clone_model_paramenter(self.model.parameters(), self.local_model)
if personalized:
self.clone_model_paramenter(self.model.parameters(), self.personalized_model_bar)
self.lr_scheduler.step(glob_iter)
if regularization and verbose:
TEACHER_LOSS=TEACHER_LOSS.detach().numpy() / (self.local_epochs * self.K)
LATENT_LOSS=LATENT_LOSS.detach().numpy() / (self.local_epochs * self.K)
info='\nUser Teacher Loss={:.4f}'.format(TEACHER_LOSS)
info+=', Latent Loss={:.4f}'.format(LATENT_LOSS)
print(info)
我这里给大家画了个损失的计算图:
大家可以对照着去这部分代码去看。其实很简单,一个表示拥护的预测损失,还有对应的是G产生样本再经过模型的预测层
θ
p
\theta^p
θp进行预测的潜在损失。最后就是针对为教师网络的对应损失,这里使用的是不重复标签(也就是一个类别的y只有一个)。
接下来我们看看服务器的更新:
def train_generator(self, batch_size, epoches=1, latent_layer_idx=-1, verbose=False):
"""
Learn a generator that find a consensus latent representation z, given a label 'y'.
:param batch_size:
:param epoches:
:param latent_layer_idx: if set to -1 (-2), get latent representation of the last (or 2nd to last) layer.
:param verbose: print loss information.
:return: Do not return anything.
"""
#self.generative_regularizer.train()
self.label_weights, self.qualified_labels = self.get_label_weights()
TEACHER_LOSS, STUDENT_LOSS, DIVERSITY_LOSS, STUDENT_LOSS2 = 0, 0, 0, 0
def update_generator_(n_iters, student_model, TEACHER_LOSS, STUDENT_LOSS, DIVERSITY_LOSS):
self.generative_model.train()
student_model.eval()
for i in range(n_iters):
self.generative_optimizer.zero_grad()
y=np.random.choice(self.qualified_labels, batch_size)
y_input=torch.LongTensor(y)
## feed to generator
gen_result=self.generative_model(y_input, latent_layer_idx=latent_layer_idx, verbose=True)
# get approximation of Z( latent) if latent set to True, X( raw image) otherwise
gen_output, eps=gen_result['output'], gen_result['eps']
##### get losses ####
# decoded = self.generative_regularizer(gen_output)
# regularization_loss = beta * self.generative_model.dist_loss(decoded, eps) # map generated z back to eps
diversity_loss=self.generative_model.diversity_loss(eps, gen_output) # encourage different outputs
######### get teacher loss ############
teacher_loss=0
teacher_logit=0
for user_idx, user in enumerate(self.selected_users):
user.model.eval()
weight=self.label_weights[y][:, user_idx].reshape(-1, 1)
expand_weight=np.tile(weight, (1, self.unique_labels))
user_result_given_gen=user.model(gen_output, start_layer_idx=latent_layer_idx, logit=True)
user_output_logp_=F.log_softmax(user_result_given_gen['logit'], dim=1)
teacher_loss_=torch.mean( \
self.generative_model.crossentropy_loss(user_output_logp_, y_input) * \
torch.tensor(weight, dtype=torch.float32))
teacher_loss+=teacher_loss_
teacher_logit+=user_result_given_gen['logit'] * torch.tensor(expand_weight, dtype=torch.float32)
######### get student loss ############
student_output=student_model(gen_output, start_layer_idx=latent_layer_idx, logit=True)
student_loss=F.kl_div(F.log_softmax(student_output['logit'], dim=1), F.softmax(teacher_logit, dim=1))
if self.ensemble_beta > 0:
loss=self.ensemble_alpha * teacher_loss - self.ensemble_beta * student_loss + self.ensemble_eta * diversity_loss
else:
loss=self.ensemble_alpha * teacher_loss + self.ensemble_eta * diversity_loss
loss.backward()
self.generative_optimizer.step()
TEACHER_LOSS += self.ensemble_alpha * teacher_loss#(torch.mean(TEACHER_LOSS.double())).item()
STUDENT_LOSS += self.ensemble_beta * student_loss#(torch.mean(student_loss.double())).item()
DIVERSITY_LOSS += self.ensemble_eta * diversity_loss#(torch.mean(diversity_loss.double())).item()
return TEACHER_LOSS, STUDENT_LOSS, DIVERSITY_LOSS
for i in range(epoches):
TEACHER_LOSS, STUDENT_LOSS, DIVERSITY_LOSS=update_generator_(
self.n_teacher_iters, self.model, TEACHER_LOSS, STUDENT_LOSS, DIVERSITY_LOSS)
TEACHER_LOSS = TEACHER_LOSS.detach().numpy() / (self.n_teacher_iters * epoches)
STUDENT_LOSS = STUDENT_LOSS.detach().numpy() / (self.n_teacher_iters * epoches)
DIVERSITY_LOSS = DIVERSITY_LOSS.detach().numpy() / (self.n_teacher_iters * epoches)
info="Generator: Teacher Loss= {:.4f}, Student Loss= {:.4f}, Diversity Loss = {:.4f}, ". \
format(TEACHER_LOSS, STUDENT_LOSS, DIVERSITY_LOSS)
if verbose:
print(info)
self.generative_lr_scheduler.step()
其中多样性损失没有写进去,大家根据代码找到即可。