wasserstein 距离(原理+Pytorch 代码实现)

1.原理

1.1 原论文
最初应用在生成模型上
《Wasserstein GAN》: https://arxiv.org/abs/1701.07875

1.2 全面介绍

https://zhuanlan.zhihu.com/p/25071913

1.3 优点

Wasserstein距离相比KL散度、JS散度的优越性在于,即便两个分布没有重叠,Wasserstein距离仍然能够反映它们的远近。

2.代码


import torch
import torch.nn as nn

# Adapted from https://github.com/gpeyre/SinkhornAutoDiff
class SinkhornDistance(nn.Module):
    r"""
    Given two empirical measures each with :math:`P_1` locations
    :math:`x\in\mathbb{R}^{D_1}` and :math:`P_2` locations :math:`y\in\mathbb{R}^{D_2}`,
    outputs an approximation of the regularized OT cost for point clouds.
    Args:
        eps (float): regularization coefficient
        max_iter (int): maximum number of Sinkhorn iterations
        reduction (string, optional): Specifies the reduction to apply to the output:
            'none' | 'mean' | 'sum'. 'none': no reduction will be applied,
            'mean': the sum of the output will be divided by the number of
            elements in the output, 'sum': the output will be summed. Default: 'none'
    Shape:
        - Input: :math:`(N, P_1, D_1)`, :math:`(N, P_2, D_2)`
        - Output: :math:`(N)` or :math:`()`, depending on `reduction`
    """
    def __init__(self, eps, max_iter, reduction='none'):
        super(SinkhornDistance, self).__init__()
        self.eps = eps
        self.max_iter = max_iter
        self.reduction = reduction

    def forward(self, x, y):
        # The Sinkhorn algorithm takes as input three variables :
        C = self._cost_matrix(x, y)
<think>好的,用户想了解如何在PyTorch中计算Wasserstein距离。首先,我需要回忆Wasserstein距离的基本概念。它通常用于衡量两个概率分布之间的差异,尤其是在最优运输问题中。根据不同的情况,计算方法可能不同。 对于一维情况,计算相对简单,可以通过排序后计算累积分布差异的积分。而对于高维情况,可能需要使用Sinkhorn算法或者对偶形式。用户可能需要的是代码示例,所以得考虑PyTorch中如何实现这两种情况。 首先,一维情况。假设有两个概率分布,比如两个批量的样本,可以通过排序后计算差的绝对值之和。这里需要注意,样本需要排序,然后逐点计算差异的平均。例如,使用torch.sort对两个样本进行排序,然后计算差的均值。这种方法适用于经验分布,并且计算效率高。 然后是高维情况,这时候可能需要使用Sinkhorn迭代。因为直接计算Wasserstein距离在高维空间计算量很大,Sinkhorn通过引入熵正则化来近似,使得计算可行。需要实现Sinkhorn算法,这涉及到生成成本矩阵,然后进行迭代更新。这里可能会用到矩阵操作,比如计算距离矩阵,然后进行多次迭代,每一步更新对偶变量u和v,最后计算最终的损失值。 另外,用户可能需要了解Wasserstein距离的应用场景,比如生成对抗网络(WGAN)。需要确保代码示例的正确性,比如使用合适的函数和参数。同时,要注意PyTorch中的自动梯度计算,是否需要关闭梯度,特别是在排序步骤中,因为排序操作可能导致梯度问题,需要测试。 还需要检查代码中的变量是否正确,比如在一维情况下,是否处理了批量数据,是否需要求平均。在高维情况下,成本矩阵的计算是否正确,比如使用torch.cdist计算样本之间的距离,正则化参数epsilon和迭代次数的选择是否合理。 最后,生成相关问题,可能涉及Wasserstein距离的其他计算方法、与KL散度的区别、在WGAN中的应用以及性能优化技巧。需要确保问题相关且有价值,帮助用户进一步学习。</think>### PyTorchWasserstein距离的计算方法 Wasserstein距离(又称Earth Mover's Distance)在深度学习中常用于衡量概率分布差异,尤其在生成对抗网络(WGAN)中有重要应用[^1]。以下是两种典型场景的实现方法: --- ### 一、一维分布计算(经验分布) 适用于排序后可直接计算的情况: ```python def wasserstein_1d(u_values, v_values): u_sorted = torch.sort(u_values, dim=0).values v_sorted = torch.sort(v_values, dim=0).values return torch.mean(torch.abs(u_sorted - v_sorted)) ``` 示例用法: ```python # 生成样本数据 samples1 = torch.randn(1000) samples2 = torch.randn(1000) + 2.0 # 计算距离 distance = wasserstein_1d(samples1, samples2) print(f"Wasserstein距离: {distance.item():.4f}") ``` --- ### 二、高维分布计算(Sinkhorn算法) 通过熵正则化实现高效计算: ```python def sinkhorn_loss(x, y, epsilon=0.01, max_iters=100): # 计算成本矩阵 C = torch.cdist(x, y)**2 # L2距离平方 # 初始化对偶变量 u = torch.zeros(x.size(0), device=x.device) v = torch.zeros(y.size(0), device=y.device) for _ in range(max_iters): u = epsilon * (torch.logsumexp((-C + u.unsqueeze(1) + v)/epsilon, dim=1)) v = epsilon * (torch.logsumexp((-C + u.unsqueeze(1) + v)/epsilon, dim=0)) # 计算最终损失 return torch.mean(u) + torch.mean(v) ``` 示例用法: ```python # 生成高维样本 batch1 = torch.randn(64, 256) # 64个256维样本 batch2 = torch.randn(64, 256) # 计算近似Wasserstein距离 loss = sinkhorn_loss(batch1, batch2) print(f"Sinkhorn损失值: {loss.item():.4f}") ``` --- ### 关键点说明: 1. **一维方法原理**:当样本排序后,Wasserstein距离等价于对应分位点差的绝对值平均[^2] 2. **Sinkhorn优势**:通过$O(n^2)$复杂度避免了直接求解线性规划问题 3. **正则化参数**:epsilon值越小越接近真实Wasserstein距离,但数值稳定性会降低
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值