论文 Confidence-Aware Personalized Federated Learning via Variational Expectation Maximization(2023 CVPR)的阅读笔记,可在 arXiv 在线阅读
1. Abstract
本文针对 FL
中各个客户的数据集非独立同分布以及数据集大小不同的场景。pFedVEM
是一个基于层次贝叶斯模型和变分推断的 PFL
框架。其中每个客户有一个本地模型,所有客户共享一个全局模型,本地模型和全局模型采用同样的模型架构。全局模型被建模为一个潜变量,作用是增强客户本地模型参数的联合分布,以及捕获不同客户的共同趋势。
算法给出了一个置信度值的估计,用于提供每个客户在聚合阶段的权重参数。
2. Symbol Table
符号 | 含义 |
---|---|
D j \mathcal{D}_j Dj | 第 j j j 个客户的数据集 |
n j n_j nj | 第 j j j 个客户数据中实例数量 |
x i ( j ) \boldsymbol{x}_i^{\left(j\right)} xi(j) | 第 j j j 个客户数据集中的第 i i i 条实例 |
y i ( j ) y_i^{\left(j\right)} yi(j) | 第 j j j 个客户数据集中的第 i i i 条实例的标签(单标签) |
w j \boldsymbol{w}_j wj | 第 j j j 个客户的模型权重 |
J J J | 客户数量 |
3. Problem Formulation
在本文的设置中,一共有 J J J 个客户, J J J 个本地模型和 1 1 1 个全局模型,所有的模型采用相同的模型架构(区别只是模型参数不同)。全局模型不用于预测,它只作为辅助的潜变量优化局部模型。每个局部模型各自在自己的数据集上进行预测。
在 PFL
中,要解决的是如下优化问题,也就是要获得使得总的平均损失最小的本地模型集合
min
w
1
:
J
∈
R
d
f
(
w
1
:
J
;
D
)
:
=
1
J
∑
j
=
1
J
f
j
(
w
j
;
D
j
)
\underset{\boldsymbol{w}_{1:J}\in\mathbb{R}^d}{\min}f\left(\boldsymbol{w}_{1:J};\mathcal{D}\right):=\frac{1}{J}\sum\limits_{j=1}^Jf_j\left(\boldsymbol{w}_j;\mathcal{D}_j\right)
w1:J∈Rdminf(w1:J;D):=J1j=1∑Jfj(wj;Dj)
f
j
(
w
j
;
D
j
)
f_j\left(\boldsymbol{w}_j;\mathcal{D}_j\right)
fj(wj;Dj) 是第
j
j
j 个客户的损失
f
j
(
w
j
;
D
j
)
=
1
n
j
∑
i
=
1
n
j
l
(
x
i
(
j
)
,
y
i
(
j
)
;
w
)
f_j\left(\boldsymbol{w}_j;\mathcal{D}_j\right)=\frac{1}{n_j}\sum\limits_{i=1}^{n_j}l\left(\boldsymbol{x}_i^{\left(j\right)}, y_i^{\left(j\right)};\boldsymbol{w}\right)
fj(wj;Dj)=nj1i=1∑njl(xi(j),yi(j);w)
l
(
⋅
,
⋅
;
w
)
l\left(\cdot,\cdot;\boldsymbol{w}\right)
l(⋅,⋅;w) 是当
w
j
=
w
\boldsymbol{w}_j=\boldsymbol{w}
wj=w 时,模型在数据点
(
⋅
,
⋅
)
\left(\cdot, \cdot\right)
(⋅,⋅) 上的损失。
4. Confidence-aware PFL
本文为 PFL
提出了一个通用的贝叶斯框架,然后基于变分期望最大化来进行模型优化。达到的效果是每个客户可以为自己的本地训练结果估计一个置信度分数,该分数被用在模型聚合过程中自适应的调教权重(全局模型由本地模型加权聚合得到,做的事情就是如何自适应的得到这个权重)
4.1 Hierarchical Bayesian modeling
从贝叶斯的角度来说,PFL
要做的事情就是为任意客户
j
j
j 获取后验概率密度函数
p
(
w
j
∣
D
j
)
p\left(\boldsymbol{w}_j\mid\mathcal{D}_j\right)
p(wj∣Dj)。由于在本文的设置中,所有的客户都运行相似的任务,因此对一个客户而言,其余客户的数据集也应该有助于构建该客户的后验。但是,每个客户不能访问其余客户的隐私数据集,因此,通过构建一个潜变量全局模型来捕获不同客户之间的相关性,每个客户可借助该全局模型从其他客户那里获得帮助,这就是本文模型的核心思想。
用
w
\boldsymbol{w}
w 代表全局模型的模型参数,本文假设在不同的客户之间有条件独立性如下
p
(
w
i
∣
w
)
p
(
w
j
∣
w
)
=
p
(
w
i
,
w
j
∣
w
)
p\left(\boldsymbol{w}_i\mid\boldsymbol{w}\right)p\left(\boldsymbol{w}_j\mid\boldsymbol{w}\right) = p\left(\boldsymbol{w}_i, \boldsymbol{w}_j\mid\boldsymbol{w}\right)
p(wi∣w)p(wj∣w)=p(wi,wj∣w)
p
(
w
j
∣
w
)
p\left(\boldsymbol{w}_j\mid\boldsymbol{w}\right)
p(wj∣w) 的含义就是在基于所有客户数据集的共有趋势的前提下,对第
j
j
j 个客户的模型进行定制。由于添加了全局模型
w
\boldsymbol{w}
w,因此,要求解的最佳模型参数集扩展为
{
w
,
w
1
:
J
}
\left\{\boldsymbol{w}, \boldsymbol{w}_{1:J}\right\}
{w,w1:J},因此,要做的就是获得后验
p
(
w
,
w
1
:
J
∣
D
)
p\left(\boldsymbol{w}, \boldsymbol{w}_{1:J}\mid\mathcal{D}\right)
p(w,w1:J∣D)
根据贝叶斯规则有
p
(
w
,
w
1
:
J
∣
D
)
=
p
(
D
∣
w
,
w
1
:
J
)
p
(
w
,
w
1
:
J
)
p
(
D
)
p\left(\boldsymbol{w}, \boldsymbol{w}_{1:J}\mid\mathcal{D}\right)=\frac{p\left(\mathcal{D}\mid\boldsymbol{w}, \boldsymbol{w}_{1:J}\right)p\left(\boldsymbol{w}, \boldsymbol{w}_{1:J}\right)}{p\left(\mathcal{D}\right)}
p(w,w1:J∣D)=p(D)p(D∣w,w1:J)p(w,w1:J)
其正比于等式右方的分子部分,又有
p
(
D
∣
w
j
)
=
exp
(
−
n
j
f
j
(
w
j
;
D
j
)
)
p
(
w
,
w
1
:
J
)
=
p
(
w
)
∏
j
=
1
J
p
(
w
j
∣
w
)
\begin{aligned} p\left(\mathcal{D}\mid\boldsymbol{w}_j\right) &= \exp\left(-n_jf_j\left(\boldsymbol{w}_j; \mathcal{D}_j\right)\right) \\ p\left(\boldsymbol{w}, \boldsymbol{w}_{1:J}\right) &= p\left(\boldsymbol{w}\right)\prod\limits_{j=1}^Jp\left(\boldsymbol{w}_j\mid\boldsymbol{w}\right) \end{aligned}
p(D∣wj)p(w,w1:J)=exp(−njfj(wj;Dj))=p(w)j=1∏Jp(wj∣w)
其中第一个等式为似然函数定义的展开(有一点小变换,相当于先对似然函数取对数将连乘变成连加,然后又进行指数操作,还原为似然函数),第二个等式因为前面的条件独立性假设而成立
因此,模型转变为
p
(
w
)
∏
j
=
1
J
p
(
w
j
∣
w
)
exp
(
−
n
j
f
j
(
w
j
;
D
j
)
)
p\left(\boldsymbol{w}\right)\prod\limits_{j=1}^Jp\left(\boldsymbol{w}_j\mid\boldsymbol{w}\right)\exp\left(-n_jf_j\left(\boldsymbol{w}_j; \mathcal{D}_j\right)\right)
p(w)j=1∏Jp(wj∣w)exp(−njfj(wj;Dj))
这表明边际联合分布
p
(
w
1
:
J
∣
D
)
=
∫
p
(
w
,
w
1
:
J
∣
D
)
d
w
p\left(\boldsymbol{w}_{1:J}\mid\mathcal{D}\right)=\int p\left(\boldsymbol{w}, \boldsymbol{w}_{1:J}\mid\mathcal{D}\right)d\boldsymbol{w}
p(w1:J∣D)=∫p(w,w1:J∣D)dw 在引入了全局模型作为潜变量后变得更加灵活,能够允许客户之间进行复杂的通信。
4.2 Variational expectation maximization
本文选择先验(推理问题不关注先验,选一个即可)
p
(
w
j
∣
w
)
p\left(\boldsymbol{w}_j\mid\boldsymbol{w}\right)
p(wj∣w) 为一个各项同性高斯条件分布,即
p
(
w
j
∣
w
)
=
N
(
w
j
∣
w
,
ρ
j
2
I
)
p\left(\boldsymbol{w}_j\mid\boldsymbol{w}\right) = \mathcal{N}\left(\boldsymbol{w}_j\mid\boldsymbol{w},\rho_j^2\boldsymbol{I}\right)
p(wj∣w)=N(wj∣w,ρj2I)
其中
ρ
j
2
\rho_j^2
ρj2 代表分布的方差
4.3 Maximizing the marginal likelihood
优化贝叶斯模型的一种方法是使用 MAP
,但是选择的先验给出了一组超参
ρ
1
:
J
\rho_{1:J}
ρ1:J,这很难选择,因此,本文采用的方法是用分布
q
(
w
1
:
J
)
q\left(\boldsymbol{w}_{1:J}\right)
q(w1:J) 来近似后验分布
p
(
w
1
:
J
∣
D
)
p\left(\boldsymbol{w}_{1:J}\mid\mathcal{D}\right)
p(w1:J∣D),假设各客户独立,分布
q
(
w
1
:
J
)
q\left(\boldsymbol{w}_{1:J}\right)
q(w1:J) 被定义为一组分布的乘积
q
(
w
1
:
J
)
:
=
∏
j
=
1
J
q
j
(
w
j
)
q\left(\boldsymbol{w}_{1:J}\right) := \prod\limits_{j=1}^Jq_j\left(\boldsymbol{w}_j\right)
q(w1:J):=j=1∏Jqj(wj)
分布
q
j
(
w
j
)
q_j\left(\boldsymbol{w}_j\right)
qj(wj) 被选择为多元高斯分布,即
q
j
(
w
j
)
=
N
(
w
j
∣
μ
j
,
Σ
j
)
q_j\left(\boldsymbol{w}_j\right)=\mathcal{N}\left(\boldsymbol{w_j}\mid\boldsymbol{\mu}_j, \boldsymbol{\Sigma}_j\right)
qj(wj)=N(wj∣μj,Σj)
KL
散度用于衡量两个分布的相似程度,因此,需要的最佳
q
(
w
1
:
J
)
q\left(\boldsymbol{w}_{1:J}\right)
q(w1:J) 由下式得到
q
∗
(
w
1
:
J
)
=
arg
min
q
(
w
1
:
J
)
KL
(
q
(
w
1
:
J
)
∥
p
(
w
1
:
J
∣
D
)
)
q^*\left(\boldsymbol{w}_{1:J}\right)=\underset{q\left(\boldsymbol{w}_{1:J}\right)}{\arg\min}\operatorname{KL}\left(q\left(\boldsymbol{w}_{1:J}\right)\Vert p\left(\boldsymbol{w}_{1:J}\mid\mathcal{D}\right)\right)
q∗(w1:J)=q(w1:J)argminKL(q(w1:J)∥p(w1:J∣D))
又有 KL
散度展开有
KL
(
q
(
w
1
:
J
)
∥
p
(
w
1
:
J
∣
D
)
)
=
log
p
(
D
;
w
1
:
J
)
−
ELBO
(
q
(
w
1
:
J
)
;
ρ
1
:
J
2
,
w
)
\operatorname{KL}\left(q\left(\boldsymbol{w}_{1:J}\right)\Vert p\left(\boldsymbol{w}_{1:J}\mid\mathcal{D}\right)\right) = \log p\left(\mathcal{D};\boldsymbol{w}_{1:J}\right) - \operatorname{ELBO}\left(q\left(\boldsymbol{w}_{1:J}\right);\rho_{1:J}^2,\boldsymbol{w}\right)
KL(q(w1:J)∥p(w1:J∣D))=logp(D;w1:J)−ELBO(q(w1:J);ρ1:J2,w)
又由于 KL
散度的非负性,因此有
log
p
(
D
;
w
1
:
J
)
>
ELBO
(
q
(
w
1
:
J
)
;
ρ
1
:
J
2
,
w
)
\log p\left(\mathcal{D};\boldsymbol{w}_{1:J}\right) > \operatorname{ELBO}\left(q\left(\boldsymbol{w}_{1:J}\right);\rho_{1:J}^2,\boldsymbol{w}\right)
logp(D;w1:J)>ELBO(q(w1:J);ρ1:J2,w)
因此,最优化问题转为极大化证据下界 ELBO
,将其展开如下
ELBO
(
q
(
w
1
:
J
)
;
ρ
1
:
J
2
,
w
)
=
∑
j
=
1
J
(
E
q
(
w
j
)
[
log
p
(
D
j
∣
w
j
)
]
−
KL
[
q
(
w
j
∥
p
(
w
j
∣
w
,
ρ
j
2
)
)
]
)
=
∑
j
=
1
J
(
E
q
(
w
j
)
[
log
p
(
D
j
∣
w
j
)
]
−
E
q
(
w
j
)
[
log
q
(
w
j
)
]
+
E
q
(
w
j
)
[
log
p
(
w
j
,
w
,
ρ
j
2
)
]
)
−
∑
j
=
1
J
log
p
(
w
,
ρ
j
2
)
=
∑
j
=
1
J
(
E
q
(
w
j
)
[
log
p
(
D
j
,
w
j
)
−
log
p
(
w
j
)
−
log
q
(
w
j
)
+
log
p
(
w
j
,
w
,
ρ
j
2
)
−
log
p
(
w
,
ρ
j
2
)
]
)
=
∑
j
=
1
J
(
E
q
(
w
j
)
[
log
p
(
D
j
,
w
j
)
−
log
q
(
w
j
)
]
)
∝
∑
j
=
1
J
(
E
q
(
w
j
)
[
log
p
(
D
j
∣
w
j
,
w
)
+
log
p
(
w
j
∣
w
,
ρ
j
2
)
]
)
∝
∑
j
=
1
J
E
q
(
w
j
)
[
log
p
(
w
j
∣
w
,
ρ
j
2
)
]
\begin{aligned} \operatorname{ELBO}\left(q\left(\boldsymbol{w}_{1:J}\right);\rho_{1:J}^2,\boldsymbol{w}\right) &= \sum\limits_{j=1}^J\left(\mathbb{E}_{q\left(\boldsymbol{w}_j\right)}\left[\log p\left(\mathcal{D}_j\mid\boldsymbol{w}_j\right)\right] - \operatorname{KL}\left[q\left(\boldsymbol{w}_j\Vert p\left(\boldsymbol{w}_j\mid\boldsymbol{w}, \rho_j^2\right)\right)\right]\right) \\ &= \sum\limits_{j=1}^J\left(\mathbb{E}_{q\left(\boldsymbol{w}_j\right)}\left[\log p\left(\mathcal{D}_j\mid\boldsymbol{w}_j\right)\right] - \mathbb{E}_{q\left(\boldsymbol{w}_j\right)}\left[\log q\left(\boldsymbol{w}_j\right)\right] + \mathbb{E}_{q\left(\boldsymbol{w}_j\right)}\left[\log p\left(\boldsymbol{w}_j, \boldsymbol{w}, \rho_j^2\right) \right]\right) - \sum\limits_{j=1}^J\log p\left(\boldsymbol{w}, \rho_j^2\right)\\ &= \sum\limits_{j=1}^J\left(\mathbb{E}_{q\left(\boldsymbol{w}_j\right)}\left[\log p\left(\mathcal{D}_j, \boldsymbol{w}_j\right) - \log p\left(\boldsymbol{w}_j\right) - \log q\left(\boldsymbol{w}_j\right) + \log p\left(\boldsymbol{w}_j, \boldsymbol{w}, \rho_j^2\right) - \log p\left(\boldsymbol{w}, \rho_j^2\right)\right]\right)\\ &= \sum\limits_{j=1}^J\left(\mathbb{E}_{q\left(\boldsymbol{w}_j\right)}\left[\log p\left(\mathcal{D}_j, \boldsymbol{w}_j\right) - \log q\left(\boldsymbol{w}_j\right)\right]\right) \\ &\propto \sum\limits_{j=1}^J\left(\mathbb{E}_{q\left(\boldsymbol{w}_j\right)}\left[\log p\left(\mathcal{D}_j\mid\boldsymbol{w}_j, \boldsymbol{w}\right) + \log p\left(\boldsymbol{w}_j\mid\boldsymbol{w},\rho_j^2\right)\right]\right) \\ &\propto \sum\limits_{j=1}^J\mathbb{E}_{q\left(\boldsymbol{w}_j\right)}\left[\log p\left(\boldsymbol{w}_j\mid\boldsymbol{w},\rho_j^2\right)\right] \end{aligned}
ELBO(q(w1:J);ρ1:J2,w)=j=1∑J(Eq(wj)[logp(Dj∣wj)]−KL[q(wj∥p(wj∣w,ρj2))])=j=1∑J(Eq(wj)[logp(Dj∣wj)]−Eq(wj)[logq(wj)]+Eq(wj)[logp(wj,w,ρj2)])−j=1∑Jlogp(w,ρj2)=j=1∑J(Eq(wj)[logp(Dj,wj)−logp(wj)−logq(wj)+logp(wj,w,ρj2)−logp(w,ρj2)])=j=1∑J(Eq(wj)[logp(Dj,wj)−logq(wj)])∝j=1∑J(Eq(wj)[logp(Dj∣wj,w)+logp(wj∣w,ρj2)])∝j=1∑JEq(wj)[logp(wj∣w,ρj2)]
可以得到闭式解,同时定义每个客户的置信度分数为
τ
j
:
=
1
/
ρ
j
2
\tau_j:=1/\rho_j^2
τj:=1/ρj2,有
τ
j
=
d
/
(
Tr
(
Σ
j
)
+
∥
μ
j
−
w
∥
2
)
\tau_j = d/\left(\operatorname{Tr}\left(\boldsymbol{\Sigma_j}\right) + \left\Vert\boldsymbol{\mu}_j - \boldsymbol{w}\right\Vert^2\right)
τj=d/(Tr(Σj)+
μj−w
2)
其中
d
d
d 是
w
\boldsymbol{w}
w 的维度,同时定义服务器聚合权重的公式如下
w
∗
=
∑
j
=
1
J
τ
j
μ
j
∑
j
=
1
J
τ
j
\boldsymbol{w}^* = \frac{\sum\limits_{j=1}^J\tau_j\boldsymbol{\mu}_j}{\sum\limits_{j=1}^J\tau_j}
w∗=j=1∑Jτjj=1∑Jτjμj
直观来说,客户的不确定性越小,模型偏差越小,置信度分数就越高,置信度分数越高的客户,在模型聚合时的权重越大。
5. Algorithm Implementation
为了能够用梯度下降进行训练,需要对分布进行重参数化并采样,具体来说,用
(
μ
1
:
J
,
π
1
:
J
)
\left(\boldsymbol{\mu}_{1:J}, \boldsymbol{\pi}_{1:J}\right)
(μ1:J,π1:J) 参数化高斯分布,并以如下形式从中采样
q
(
w
j
)
=
μ
j
+
diag
(
log
(
1
+
exp
(
π
j
)
)
)
⋅
N
(
0
,
I
d
)
q\left(\boldsymbol{w}_j\right) = \boldsymbol{\mu}_j + \operatorname{diag}\left(\log\left(1+\exp\left(\boldsymbol{\pi}_j\right)\right)\right)\cdot\mathcal{N}\left(0,\boldsymbol{I}_d\right)
q(wj)=μj+diag(log(1+exp(πj)))⋅N(0,Id)
此外,使用蒙特卡洛算法近似 ELBO
中第一项,即
E
q
(
w
j
)
[
log
p
(
D
j
∣
w
j
)
]
\mathbb{E}_{q\left(\boldsymbol{w}_j\right)}\left[\log p\left(\mathcal{D}_j\mid\boldsymbol{w}_j\right)\right]
Eq(wj)[logp(Dj∣wj)] 的值,因此,当进行
K
K
K 次采样时,目标函数变为
n
j
K
∑
k
=
1
K
f
j
(
D
j
;
w
j
,
k
)
−
KL
[
q
(
w
j
∥
p
(
w
j
∣
w
,
ρ
j
2
)
)
]
\frac{n_j}{K}\sum\limits_{k=1}^Kf_j\left(\mathcal{D}_j;\boldsymbol{w}_{j,k}\right) - \operatorname{KL}\left[q\left(\boldsymbol{w}_j\Vert p\left(\boldsymbol{w}_j\mid\boldsymbol{w}, \rho_j^2\right)\right)\right]
Knjk=1∑Kfj(Dj;wj,k)−KL[q(wj∥p(wj∣w,ρj2))]
算法伪代码如下,一共进行
T
T
T 轮迭代,每轮迭代中,所有客户先用全局模型参数本地更新自己的模型参数,其中
μ
\boldsymbol{\mu}
μ 代表 head model
参数(用于学习特征表达),
θ
\boldsymbol{\theta}
θ 代表 base model
参数(用于执行分类)。然后服务器随机选择一批客户,被选中的客户将自己的置信度分数与模型参数一并上传服务器。服务器根据权重计算公式加权聚合得到新一轮的全局模型参数并将其返还给客户。