Abstract
在预训练的视觉-语言模型中进行提示学习已经在各种下游任务中展现出灵活性,利用其固有的轻量级特性,最近的研究尝试将强大的预训练模型集成到联邦学习框架中,以同时降低通信成本并促进对数据缺乏的本地模型训练
当前的联邦提示学习方法缺乏专门的设计来系统地解决严重的数据异质性,例如,涉及标签和特征偏移的数据分布。本文提出基于最优传输的联邦提示协作(Federated Prompts Cooperation via Optimal Transport,FedOTP),引入了高效的协作提示学习策略,以在每个客户端上捕获多样的类别特征。具体来说,对于每个客户端学习一个全局提示来提取客户端之间的“共识”(consensus)知识,以及一个本地提示来捕获客户端特定的类别特征。然后,利用非平衡最优传输来将视觉特征与这些提示对齐,以在全局共识和本地个性化之间取得平衡。具有多种异质性类型的数据集上进行的实验结果表明,FedOTP优于当前SOTA
Introduction
联邦学习是一种分布式机器学习框架,可以实现模型间的去中心化合作,而无需共享训练数据。由于需要更新和共享模型参数与服务器,当前的联邦学习方法涉及高昂的训练和通信成本。这一限制通常将这些方法局限于一定参数量的Backbone,从而限制了它们的特征空间,导致了性能有限和训练不稳定性
CLIP类视觉-语言预训练模型可以学习适用于各种图像分布的鲁棒和有效表示,与联邦学习的目标相一致。然而,服务器与客户端之间的大量通信开销使得在联邦学习环境中训练CLIP成为一项挑战。此外,当大规模模型使用有限的客户端数据进行训练时,可能会出现过拟合问题。提示学习提供了一种灵活的方式,通过仅训练额外的参数来使预训练模型适应下游任务。这使得提示可以捕获任务特定信息,同时提升固定模型的性能。借助其轻量级特性,先前的研究已经探索了将提示学习集成到联邦学习框架中以克服上述问题的方法
在现实场景中,客户端数据通常表现出域差异(特征偏移)或不平衡的类别分布(标签偏移)。简单地在所有客户端上应用FedAvg方法会导致其偏移客户端的局部分布,导致性能不佳。因此,设计专门的个性化联邦提示学习方法以有效解决数据异质性是至关重要的。FedPrompt通过保持个性化的注意力模块来引入客户端个性化,以在本地学习多客户端共识的同时生成本地的局部视觉特征。然而,在存在显著的标签/特征偏移时,仅仅学习语言模态中的共享提示是不够的。为了解决这些限制,提出了在本地训练阶段同时学习全局共享提示和个性化本地提示的方法
方法:在本地训练后,本地提示保留在本地,而全局提示传输到服务器上,并与来自其他客户端的提示一起使用FedAvg进行聚合。通过这种方式,客户端具有从全局提示中获取客户端之间共识知识的能力,同时还能够通过本地提示识别特定于客户端的用户特征。为进一步实现全局共识和本地个性化之间的平衡,引入了通过最优传输实现的提示协作(FedOTP)。FedOTP利用OT来通过自适应最优传输计划将本地视觉特征与全局和本地文本特征对齐,促进跨模态的细粒度匹配,并加强全局和本地提示之间的协作。自适应的OT传输可以提供对视觉不对齐的弹性和对特征转移的有效适应性*,标准的OT公式对传输添加了两个硬性的平等约束,导致每个图像块被分配到提示*。这可能导致提示捕获图像中一些与类别无关的信息,并因此影响最终结果,因此采用不平衡的OT,通过放松其中一个平等约束,允许提示专注于最相关的图像块而不是整个图像内容。FedOTP中应用了Dykstra算法的快速实现,使其能够迅速收敛并在迭代过程中集中于图像的核心区域
Contributions
- 首个探索在存在严重数据异质性的联邦学习中提示协作,训练一个全局提示来学习客户端之间的共识信息,并同时训练一个本地提示来捕获客户端特定的类别特征
- FedOTP利用非平衡OT来增强全局和本地提示之间合作,通过将本地视觉特征与这两个文本提示进行对齐,能够有效处理严重的数据异质性
- 在各种数据异质性的广泛采用的数据集上进行了大量实验验证有效性
Optimal Transport
最优传输最初作为优化同时移动多个项目成本的解决方案。传统的OT缺乏部分位移的灵活性。不平衡OT通过放松平等约束,并基于Kullback-Leibler散度引入软惩罚,可以通过广义Sinkhorn算法有效地解决。由于其在分布匹配方面的显著能力,OT已被应用于各种理论和实际任务,包括领域自适应、带有噪声标签的学习、因果发现、联邦学习等等。在提示学习领域,PLOT提出了为不同的上下文表示学习多个提示集,并使用OT来对齐视觉和语言模态的特征。与PLOT不同,本文采用非平衡OT来增强全局和本地提示之间的合作,通过放松一个平等约束,使得提示可以专注于最相关的图像块
最优传输是一种旨在在两个分布之间有效地转移概率质量的约束优化问题。这里简要回顾其离散情况的表述,给定两个概率向量
α
α
α 和
β
β
β,以及成本矩阵
C
∈
R
∣
α
∣
×
∣
β
∣
C ∈ \mathbb{R}^{|α|×|β|}
C∈R∣α∣×∣β∣,最优传输旨在通过最小化以下目标来找到最优传输计划
T
T
T
直接优化最优传输问题会耗费大量时间,Sinkhorn算法引入了一个熵正则化项以进行快速优化。正则化的最优传输公式可以表示为:
m
i
n
T
∈
U
(
α
,
β
)
⟨
C
,
T
⟩
+
λ
⟨
T
,
l
o
g
T
⟩
min_{T ∈ U(α,β)} ⟨C, T⟩ + λ⟨T, log T⟩
minT∈U(α,β)⟨C,T⟩+λ⟨T,logT⟩,最优传输计划
T
∗
T^∗
T∗ 具有唯一的形式
T
∗
=
d
i
a
g
(
u
(
t
~
)
)
e
x
p
(
−
C
/
λ
)
d
i
a
g
(
v
(
t
~
)
)
T^∗ = diag(u^{(\tilde{t})}) exp(−C/λ)diag(v^{(\tilde{t})})
T∗=diag(u(t~))exp(−C/λ)diag(v(t~)),其中
(
t
~
)
^{(\tilde{t})}
(t~) 表示迭代次数,在每次迭代中
u
(
t
~
)
=
u
/
(
e
x
p
(
−
C
/
λ
)
v
(
t
~
−
1
)
)
u^{(\tilde{t})} = u/(exp(−C/λ)v^{(\tilde{t}-1)})
u(t~)=u/(exp(−C/λ)v(t~−1)),
v
(
t
~
)
=
v
/
(
e
x
p
(
−
C
/
λ
)
⊤
u
(
t
~
)
)
v^{(\tilde{t})} = v/(exp(−C/λ)^⊤u^{(\tilde{t})})
v(t~)=v/(exp(−C/λ)⊤u(t~))。
Methodology
FedOTP利用非平衡最优传输来加强全局和本地提示之间的协作,有效解决了标签偏移和特征偏移数据异质性的问题
Federated Learning with Global and Local Prompts
为了降低通信成本并解决数据异质性问题,每个客户端在我们的联邦学习设置中配备了一个预训练的CLIP模型和一个提示学习器。通过提示学习器,每个客户端都学习了共享的全局提示和个性化的本地提示,允许客户端提取更个性化的见解,同时保持一定程度的共识。具体来说,对于每个客户端,提示 P i P_i Pi包括全局提示 P g P_g Pg和个性化提示 P l , i P_{l,i} Pl,i,表示为 P i = [ P g , P l , i ] P_i = [P_g, P_{l,i}] Pi=[Pg,Pl,i]。在每一轮通信 t t t期间,客户端 i i i通过梯度下降初始化提示 P t , 0 i = [ P t − 1 g , P t − 1 , l , i ] P_{t,0}^i = [P_{t-1}^g, P_{t-1,l,i}] Pt,0i=[Pt−1g,Pt−1,l,i]
通过梯度下降在本地进行
R
R
R次迭代,联合更新全局和本地提示
P
t
,
r
i
=
P
t
,
r
−
1
i
−
η
∇
L
D
i
(
P
t
,
r
−
1
i
)
P_{t,r}^i = P_{t,r-1}^i - \eta\nabla L_{D_i}(P_{t,r-1}^i)
Pt,ri=Pt,r−1i−η∇LDi(Pt,r−1i)。在本地训练之后,仅更新的全局提示
P
t
,
R
g
,
i
P_{t,R}^g,i
Pt,Rg,i被传输到服务器进行聚合,以学习客户端之间的全局共识,而个性化提示保留在本地以捕获客户端特定的类别特征。聚合过程可以表示为:
FedOTP的目标函数可以表示为:
其中
L
D
i
(
P
g
,
P
l
,
i
)
=
E
(
x
j
i
,
y
j
i
)
∈
D
i
ℓ
(
f
(
P
g
,
P
l
,
i
;
x
j
i
)
,
y
j
i
)
\mathcal{L}_{D_i}(P_g, P_{l,i}) = \mathbb{E}{(x_j^i,y_j^i) \in D_i} \ell(f(P_g, P{l,i}; x_j^i), y_j^i)
LDi(Pg,Pl,i)=E(xji,yji)∈Diℓ(f(Pg,Pl,i;xji),yji),
f
(
P
g
,
P
l
,
i
;
⋅
)
f(P_g, P_{l,i}; \cdot)
f(Pg,Pl,i;⋅)表示客户端
i
i
i的个性化模型,
ℓ
(
⋅
,
⋅
)
\ell(\cdot, \cdot)
ℓ(⋅,⋅)表示交叉熵损失函数
Global-Local Prompt Cooperation by Unbalanced Optimal Transport
将提示 P g P_g Pg和 P l , i P_{l,i} Pl,i初始化为{w1, p1, · · · , ps, · · · , wL},其中wi表示词嵌入,pi表示可学习的向量。通过文本编码器 h ( ⋅ ) h(·) h(⋅),为每个类别 k k k获得了全局文本特征 H k , g = h ( P g , k ) ∈ R d f H_{k,g} = h(P_{g,k}) ∈ R^{df} Hk,g=h(Pg,k)∈Rdf和本地文本特征 H k , l = h ( P l , i , k ) ∈ R d f H_{k,l} = h(P_{l,i,k}) ∈ R^{df} Hk,l=h(Pl,i,k)∈Rdf,两者的结合表示为 H k = [ H k , g , H k , l ] H_k = [H_{k,g}, H_{k,l}] Hk=[Hk,g,Hk,l]。将图像 x x x上传到图像编码器 g ( ⋅ ) g(·) g(⋅),得到一组视觉特征 G = g ( x ) ∈ R ( V + 1 ) × d f G = g(x) ∈ R^{(V +1)×df} G=g(x)∈R(V+1)×df,其中包括一个类别 token G c ∈ R d f G_c ∈ R^{df} Gc∈Rdf和一个特征图 G m ∈ R V × d f G_m ∈ R^{V×df} Gm∈RV×df
考虑学习一个最优传输计划
T
T
T,将全局和本地文本特征
H
k
H_k
Hk与视觉特征图
G
m
G_m
Gm进行对齐。通过将特征表示为离散分布的样本,成本矩阵可以由
H
k
H_k
Hk和
G
m
G_m
Gm之间的余弦距离表示为
C
=
1
−
G
m
⊤
H
k
∈
R
V
×
2
C = 1 - G_m^{\top}H_k ∈ R^{V×2}
C=1−Gm⊤Hk∈RV×2,然后非平衡最优传输的优化目标表示为:
其中
α
∈
R
V
α ∈ R^V
α∈RV和
β
∈
R
2
β ∈ R^2
β∈R2实质上是边缘概率向量,满足
∣
α
∣
1
≥
∣
β
∣
1
=
γ
(
γ
∈
[
0
,
1
]
)
|α|_1 ≥ |β|_1 = γ (γ ∈ [0, 1])
∣α∣1≥∣β∣1=γ(γ∈[0,1])。Eq.(6)与PLOT中的公式有所不同,其使用了带有两个硬等式约束的经典最优传输,如Eq.(3)所示。这会强制提示将每个图像块映射到提示,可能导致它们从图像中捕获一些与类别无关的信息,从而影响最终结果。相比之下,FedOTP放宽了一个等式约束,允许提示仅集中在最相关的图像块上,而不是整个图像内容。此外,通过控制
γ
γ
γ,FedOTP具有调节提示在特征图上的映射大小的能力
为了快速优化,在Eq.(6)中添加了一个熵正则化项,目标函数如下:
进一步将Eq.(7)重新表述为Kullback-Leibler (KL)投影,然后解空间
U
(
α
,
β
)
U(α, β)
U(α,β)被定义为两个凸但不是仿射集的交集:
其中
C
1
C_1
C1和
C
2
C_2
C2分别是
T
12
≤
α
T_{12} ≤ α
T12≤α和
T
1
⊤
V
=
β
T^{\top}_1V = β
T1⊤V=β的解空间。采用了Dykstra算法的快速实现解Eq.(8),通过仅利用矩阵-向量乘法来有效地缩放
C
1
C_1
C1和
C
2
C_2
C2之间的迭代KL投影。
Q
=
exp
(
−
C
/
λ
)
Q = \exp(-C/λ)
Q=exp(−C/λ)和
v
(
0
)
=
1
2
v^{(0)} = \frac{1}{2}
v(0)=21,在几次迭代内可以实现快速优化解决方案:
其中
t
t
t是迭代次数,在每次迭代中,
u
(
t
)
=
min
(
1
Q
α
v
(
t
−
1
)
,
1
V
)
u^{(t)} = \min\left(\frac{1}{Qαv^{(t-1)}}, 1_V\right)
u(t)=min(Qαv(t−1)1,1V),
v
(
t
)
=
1
2
Q
T
β
u
(
t
)
v^{(t)} = \frac{1}{2Q^Tβu^{(t)}}
v(t)=2QTβu(t)1,其中
Q
α
=
Q
/
diag
(
α
)
1
V
×
2
Qα = Q/\text{diag}(α)1_V×2
Qα=Q/diag(α)1V×2和
Q
T
β
=
Q
T
/
diag
(
β
)
1
V
×
2
Q^Tβ = Q^T/\text{diag}(β)1_V×2
QTβ=QT/diag(β)1V×2。通过Eq.(9),获得了最优传输方案
T
∗
T^*
T∗和最终的Wasserstein距离
d
C
,
k
d_{C,k}
dC,k,然后在Eq.(1)中,匹配得分被以下预测概率替换:
在获得
q
(
y
=
k
∣
x
)
q(y = k|x)
q(y=k∣x)之后,同时对客户端
i
i
i的全局和本地提示中的可学习向量
p
a
a
=
1
s
{p_a}{a=1}^s
paa=1s进行交叉熵优化,方法如Eq.(2)所述。然后,使用Eq.(4)将全局提示
P
g
,
i
P_{g,i}
Pg,i发送到服务器进行聚合,同时保留本地提示。通过OT进行本地训练后,FedOTP的最终预测概率是来自全局和本地提示的信息综合。这避免了直接将两个提示的结果相加,促进协作的学习过程
Experiments