最优传输论文(十七):Sinkhorn Distances: Lightspeed Computation of Optimal Transport论文原理

本文介绍了Sinkhorn距离,一种通过熵正则化优化的最优传输距离,解决了传统最优传输计算成本高的问题。通过熵约束,Sinkhorn距离可以使用Sinkhorn迭代快速计算,并在MNIST分类任务上表现出优于经典最优传输距离的性能。实验结果显示,Sinkhorn距离在计算速度上可以快几个数量级,并且在适当选择正则化参数λ时,性能更优。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

摘要

  • 这篇文章是sinkhorn的讲解论文。
  • 最佳传输距离是概率测度和特征直方图(histograms of features)的基本距离族(family)。尽管它们具有吸引人的理论性质、在检索任务中的优异性能和直观的公式,但它们的计算涉及线性程序的分辨率(如图像特征尺寸过大),每当这些度量的支持大小或直方图的维数超过几百时,其成本就会迅速变得令人望而却步。在这项工作中,我们提出了一个新的最优运输距离族,从最大化的角度来看待运输问题。我们用熵正则化项平滑了经典的最优传输问题,并证明了所得到的最优值也是一个距离,这个距离可以通过辛霍恩矩阵(sinkhorn)缩放算法以比传输解算器(如网络单纯型法)快几个数量级的速度来计算。我们还表明,这种正则化距离改进了MNIST分类问题上的经典最优运输距离。

介绍

  • 选择合适的距离来比较概率是统计机器学习中的一个关键问题。当对支持这些概率的概率空间知之甚少时,各种带有极小假设的信息散度被提出来扮演这一角色,其中包括海灵格散度(H散度)、χ2散度、全变差散度(total variation)或库尔巴克-莱布勒散度(KL散度)
  • 当概率空间是一个度量空间时,最佳传输距离
### 最优传输在跨模态检索中的应用 #### 什么是最优传输最优传输Optimal Transport, OT)是一种衡量两个概率分布之间距离的方法,其核心在于寻找一种最经济的方式将一个分布转换为另一个分布。这种方法已被广泛应用于机器学习领域,尤其是在生成模型、域适应以及跨模态检索等方面。 #### 最优传输在跨模态检索中的作用 跨模态检索的核心问题是找到一种方式来度量来自不同模态的数据之间的相似性。传统方法通常依赖于手工设计的特征或者简单的线性映射技术。然而,这些方法可能无法充分捕捉复杂的非线性关系。相比之下,最优传输提供了一种强大的工具来解决这一挑战: 1. **分布匹配**:通过计算源模态和目标模态的概率分布之间的 Wasserstein 距离,可以有效地量化两者间的差异并指导特征对齐过程[^1]。 2. **全局优化视角**:相比于局部调整策略,最优传输允许从整体上考虑如何最佳地分配权重以实现更好的表示一致性[^4]。 3. **鲁棒性和泛化能力增强**:由于引入了更精细的距离定义机制,在面对噪声数据或风格变化较大的样本时表现出更强的稳定性[^5]。 #### 实现方法概述 以下是几种常见的基于最优传输理论构建跨模态检索系统的具体做法: ##### 方法一:联合嵌入空间下的最优传输 该方案首先建立统一的低维潜在向量表示形式,接着运用 Sinkhorn 迭代算法求解正则化的 Earth Mover's Distance(EMD),从而完成异构媒体间关联性的刻画工作。此过程中涉及到的关键组件包括但不限于: - 使用双分支架构分别提取各自类型的深层语义描述子; - 定义成本矩阵反映两组对象配对可能性大小的关系; - 基于交替方向乘数法更新参数直至收敛为止。 ```python import ot from sklearn.preprocessing import normalize def compute_ot_distance(source_features, target_features): """ Compute the optimal transport distance between two sets of features. Args: source_features (numpy.ndarray): Features from the source modality. target_features (numpy.ndarray): Features from the target modality. Returns: float: The computed OT distance. """ # Normalize feature vectors to ensure they represent valid probability distributions source_dist = normalize(source_features.sum(axis=1).reshape(-1, 1), norm='l1') target_dist = normalize(target_features.sum(axis=1).reshape(-1, 1), norm='l1') # Define cost matrix based on Euclidean distances or cosine similarities C = np.linalg.norm(source_features[:, None, :] - target_features[None, :, :], axis=-1) # Solve regularized OT problem using Sinkhorn algorithm reg = 0.1 # Regularization parameter P = ot.sinkhorn(source_dist.flatten(), target_dist.flatten(), C, reg) return (C * P).sum() ``` ##### 方法二:对抗训练辅助的最佳运输路径发现 为了进一步提升性能表现,还可以结合生成对抗网络(GAN)的思想来进行端到端的学习流程控制。具体而言就是让判别器尝试区分真实成对比实例与伪造组合案例的同时鼓励生成器创造更加逼真的混合产物出来供前者判断真假之用。这样一来不仅有助于缓解过拟合现象发生几率同时也促进了最终解决方案质量得到显著改善效果明显优于单纯依靠统计规律推导出来的结论版本号等等情况之下均能体现出独特优势所在之处值得深入探讨研究下去才行呢! --- ###
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

CtrlZ1

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值