【深度学习|迁移学习】Wasserstein距离度量和跨域原型一致性损失(CPC Loss)如何计算?以及Wasserstein距离和CPC Loss结合的对抗训练示例,附代码(一)

【深度学习|迁移学习】Wasserstein距离度量和跨域原型一致性损失(CPC Loss)如何计算?以及Wasserstein距离和CPC Loss结合的对抗训练示例,附代码(一)

【深度学习|迁移学习】Wasserstein距离度量和跨域原型一致性损失(CPC Loss)如何计算?以及Wasserstein距离和CPC Loss结合的对抗训练示例,附代码(一)



欢迎铁子们点赞、关注、收藏!
祝大家逢考必过!逢投必中!上岸上岸上岸!upupup

大多数高校硕博生毕业要求需要参加学术会议,发表EI或者SCI检索的学术论文会议论文:
详细信息可参考学术信息专栏:https://blog.csdn.net/2401_89898861/article/details/145551342

前言

Wasserstein距离(Wasserstein Distance,又称为地球移动者距离,Earth Mover’s Distance,EMD)是一种衡量两个分布之间差异的距离度量,它在对抗训练、领域适应等任务中非常有用。跨域原型一致性损失(Cross-domain Prototype Consistency Loss,CPC Loss)则是通过确保不同领域之间的类别原型保持一致来优化特征表示,帮助减少领域间的分布差异。

我们将通过代码示例详细解释Wasserstein距离和跨域原型一致性损失的计算过程。

1. Wasserstein距离(Wasserstein Distance)的计算

Wasserstein距离度量了从一个分布到另一个分布的最优传输成本。对于连续分布,Wasserstein距离可以通过解决优化问题来计算,但对于离散分布,计算则可以通过简单的样本间的最小成本匹配来进行。

1.1 Wasserstein距离的公式

给定两个分布 P P P Q Q Q,其Wasserstein距离定义为:
在这里插入图片描述其中:

  • Π ( P , Q ) Π(P,Q) Π(P,Q)表示从 P P P Q Q Q 的所有联合分布。
  • ∥ x − y ∥ ∥x−y∥ xy表示两点之间的距离(通常是欧几里得距离)。
  • 目标是找到最佳的运输计划 γ γ γ,使得从 P P P Q Q Q 的距离最小。

对于离散分布,Wasserstein距离可以通过计算所有样本点之间的距离并求最小传输成本来得到

1.2 Wasserstein距离计算的代码示例(PyTorch)

import torch
import torch.nn.functional as F

def wasserstein_distance(source_features, target_features):
    # 计算源领域和目标领域特征之间的欧几里得距离
    dist_matrix = torch.cdist(source_features, target_features, p=2)  # p=2表示欧几里得距离
    # 对距离矩阵进行归一化处理
    dist_matrix = dist_matrix / dist_matrix.max()  # 归一化到[0, 1]
    
    # 计算Wasserstein距离,使用最小值作为距离度量
    wasserstein_dist = dist_matrix.min()  # 取最小距离(即最小传输成本)
    return wasserstein_dist

解释:

  • torch.cdist()计算的是源领域和目标领域特征之间的欧几里得距离矩阵。
  • 通过将距离矩阵归一化,可以减少尺度差异的影响。
  • 最终,计算的Wasserstein距离是源领域和目标领域特征之间的最小传输成本。

2. 跨域原型一致性损失(CPC Loss)的计算

跨域原型一致性损失旨在通过确保源领域和目标领域的原型(类别中心)在特征空间中保持一致,来减少领域间的分布差异。这通常通过最小化源领域和目标领域原型之间的距离来实现。

2.1 CPC Loss的公式

假设我们有源领域 ( D s ) (D_s) (Ds) 和目标领域 ( D t ) (D_t) (Dt) 的样本,以及它们各自的类别原型 ( P s 和 P t ) (P_s和P_t) (PsPt),CPC损失的计算目标是使得同类别的源领域原型和目标领域原型之间的距离尽量小:
在这里插入图片描述其中:

P s x P^x_s Psx P t c P^c_t Ptc分别表示源领域和目标领域的第 c c c 类原型。
∥ P s c − P t c ∥ ∥P_s^c−P_t^c∥ PscPtc表示源领域和目标领域相同类别的原型之间的欧几里得距离。

2.2 CPC Loss计算的代码示例(PyTorch)

def cross_domain_prototype_consistency(source_prototypes, target_prototypes):
    """
    计算跨域原型一致性损失
    source_prototypes: 源领域的类别原型 (类别数 x 特征维度)
    target_prototypes: 目标领域的类别原型 (类别数 x 特征维度)
    """
    # 计算源领域和目标领域原型之间的欧几里得距离
    cpc_loss = torch.sum(torch.pow(source_prototypes - target_prototypes, 2))  # 欧几里得距离的平方和
    return cpc_loss

解释:

  • source_prototypes和target_prototypes分别是源领域和目标领域的类别原型,它们的维度通常是(类别数, 特征维度)
  • 通过计算源领域和目标领域的原型之间的欧几里得距离,并对所有类别求和来得到CPC损失。

下节请参考:【深度学习|迁移学习】Wasserstein距离度量和跨域原型一致性损失(CPC Loss)如何计算?以及Wasserstein距离和CPC Loss结合的对抗训练示例,附代码(二)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

努力毕业的小土博^_^

您的鼓励是我创作的动力!谢谢!

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值