【深度学习|迁移学习】Wasserstein距离度量和跨域原型一致性损失(CPC Loss)如何计算?以及Wasserstein距离和CPC Loss结合的对抗训练示例,附代码(一)
【深度学习|迁移学习】Wasserstein距离度量和跨域原型一致性损失(CPC Loss)如何计算?以及Wasserstein距离和CPC Loss结合的对抗训练示例,附代码(一)
文章目录
欢迎铁子们点赞、关注、收藏!
祝大家逢考必过!逢投必中!上岸上岸上岸!upupup
大多数高校硕博生毕业要求需要参加学术会议,发表EI或者SCI检索的学术论文会议论文:
详细信息可参考学术信息专栏:https://blog.csdn.net/2401_89898861/article/details/145551342
前言
Wasserstein距离(Wasserstein Distance,又称为地球移动者距离,Earth Mover’s Distance,EMD)是一种衡量两个分布之间差异的距离度量,它在对抗训练、领域适应等任务中非常有用。跨域原型一致性损失(Cross-domain Prototype Consistency Loss,CPC Loss)则是通过确保不同领域之间的类别原型保持一致来优化特征表示,帮助减少领域间的分布差异。
我们将通过代码示例详细解释Wasserstein距离和跨域原型一致性损失的计算过程。
1. Wasserstein距离(Wasserstein Distance)的计算
Wasserstein距离度量了从一个分布到另一个分布的最优传输成本。对于连续分布,Wasserstein距离可以通过解决优化问题来计算,但对于离散分布,计算则可以通过简单的样本间的最小成本匹配来进行。
1.1 Wasserstein距离的公式
给定两个分布
P
P
P 和
Q
Q
Q,其Wasserstein距离定义为:
其中:
- Π ( P , Q ) Π(P,Q) Π(P,Q)表示从 P P P 到 Q Q Q 的所有联合分布。
- ∥ x − y ∥ ∥x−y∥ ∥x−y∥表示两点之间的距离(通常是欧几里得距离)。
- 目标是找到最佳的运输计划 γ γ γ,使得从 P P P 到 Q Q Q 的距离最小。
对于离散分布,Wasserstein距离可以通过计算所有样本点之间的距离并求最小传输成本来得到。
1.2 Wasserstein距离计算的代码示例(PyTorch)
import torch
import torch.nn.functional as F
def wasserstein_distance(source_features, target_features):
# 计算源领域和目标领域特征之间的欧几里得距离
dist_matrix = torch.cdist(source_features, target_features, p=2) # p=2表示欧几里得距离
# 对距离矩阵进行归一化处理
dist_matrix = dist_matrix / dist_matrix.max() # 归一化到[0, 1]
# 计算Wasserstein距离,使用最小值作为距离度量
wasserstein_dist = dist_matrix.min() # 取最小距离(即最小传输成本)
return wasserstein_dist
解释:
torch.cdist()
计算的是源领域和目标领域特征之间的欧几里得距离矩阵。- 通过将距离矩阵归一化,可以减少尺度差异的影响。
- 最终,计算的Wasserstein距离是源领域和目标领域特征之间的最小传输成本。
2. 跨域原型一致性损失(CPC Loss)的计算
跨域原型一致性损失旨在通过确保源领域和目标领域的原型(类别中心)在特征空间中保持一致,来减少领域间的分布差异。这通常通过最小化源领域和目标领域原型之间的距离来实现。
2.1 CPC Loss的公式
假设我们有源领域
(
D
s
)
(D_s)
(Ds) 和目标领域
(
D
t
)
(D_t)
(Dt) 的样本,以及它们各自的类别原型
(
P
s
和
P
t
)
(P_s和P_t)
(Ps和Pt),CPC损失的计算目标是使得同类别的源领域原型和目标领域原型之间的距离尽量小:
其中:
P
s
x
P^x_s
Psx 和
P
t
c
P^c_t
Ptc分别表示源领域和目标领域的第
c
c
c 类原型。
∥
P
s
c
−
P
t
c
∥
∥P_s^c−P_t^c∥
∥Psc−Ptc∥表示源领域和目标领域相同类别的原型之间的欧几里得距离。
2.2 CPC Loss计算的代码示例(PyTorch)
def cross_domain_prototype_consistency(source_prototypes, target_prototypes):
"""
计算跨域原型一致性损失
source_prototypes: 源领域的类别原型 (类别数 x 特征维度)
target_prototypes: 目标领域的类别原型 (类别数 x 特征维度)
"""
# 计算源领域和目标领域原型之间的欧几里得距离
cpc_loss = torch.sum(torch.pow(source_prototypes - target_prototypes, 2)) # 欧几里得距离的平方和
return cpc_loss
解释:
- source_prototypes和
target_prototypes
分别是源领域和目标领域的类别原型,它们的维度通常是(类别数, 特征维度)。 - 通过计算源领域和目标领域的原型之间的欧几里得距离,并对所有类别求和来得到CPC损失。
下节请参考:【深度学习|迁移学习】Wasserstein距离度量和跨域原型一致性损失(CPC Loss)如何计算?以及Wasserstein距离和CPC Loss结合的对抗训练示例,附代码(二)