Wasserstein Distance (Theory + PyTorch Implementation)

1. Theory

1.1 Original paper
The Wasserstein distance was first applied to generative models in
《Wasserstein GAN》: https://arxiv.org/abs/1701.07875

1.2 A comprehensive introduction

https://zhuanlan.zhihu.com/p/25071913

1.3 Advantages

The advantage of the Wasserstein distance over the KL and JS divergences is that even when two distributions do not overlap, the Wasserstein distance still reflects how far apart they are.
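As a toy illustration (a hypothetical example, not from the original post): put all the mass of one 1-D distribution at 0 and all the mass of another at θ. Their supports are disjoint for any θ ≠ 0, so the KL divergence is infinite and the JS divergence is stuck at the constant log 2, while the 1-Wasserstein distance equals |θ| and still tells us how far apart the two distributions are. A minimal NumPy sketch:

```python
import numpy as np

def w1_sorted(a, b):
    # For equal-size 1-D samples with uniform weights, the 1-Wasserstein
    # distance is the mean absolute difference of the sorted samples.
    return np.mean(np.abs(np.sort(a) - np.sort(b)))

def js_divergence(p, q):
    # Jensen-Shannon divergence between two discrete distributions (in nats).
    def kl(a, b):
        mask = a > 0
        return np.sum(a[mask] * np.log(a[mask] / b[mask]))
    m = 0.5 * (p + q)
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

# P puts all mass at 0, Q puts all mass at theta: disjoint supports.
for theta in (1.0, 2.0, 5.0):
    w1 = w1_sorted(np.zeros(4), np.full(4, theta))  # grows linearly with theta
    # As discrete distributions on the two support points, P = (1, 0), Q = (0, 1):
    js = js_divergence(np.array([1.0, 0.0]), np.array([0.0, 1.0]))  # always log 2
```

Only W1 provides a signal (and hence a usable gradient) that shrinks as the two distributions move closer together.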

2. Code


import torch
import torch.nn as nn

# Adapted from https://github.com/gpeyre/SinkhornAutoDiff
class SinkhornDistance(nn.Module):
    r"""
    Given two empirical measures each with :math:`P_1` locations
    :math:`x\in\mathbb{R}^{D_1}` and :math:`P_2` locations :math:`y\in\mathbb{R}^{D_2}`,
    outputs an approximation of the regularized OT cost for point clouds.
    Args:
        eps (float): regularization coefficient
        max_iter (int): maximum number of Sinkhorn iterations
        reduction (string, optional): Specifies the reduction to apply to the output:
            'none' | 'mean' | 'sum'. 'none': no reduction will be applied,
            'mean': the sum of the output will be divided by the number of
            elements in the output, 'sum': the output will be summed. Default: 'none'
    Shape:
        - Input: :math:`(N, P_1, D_1)`, :math:`(N, P_2, D_2)`
        - Output: :math:`(N)` or :math:`()`, depending on `reduction`
    """
    def __init__(self, eps, max_iter, reduction='none'):
        super(SinkhornDistance, self).__init__()
        self.eps = eps
        self.max_iter = max_iter
        self.reduction = reduction

    def forward(self, x, y):
        # The Sinkhorn algorithm takes as input three variables :
        C = self._cost_matrix(x, y)  # Wasserstein cost function (computed on x's device)
        x_points = x.shape[-2]
        y_points = y.shape[-2]
        if x.dim() == 2:
            batch_size = 1
        else:
            batch_size = x.shape[0]

        # both marginals are fixed with equal (uniform) weights
        mu = torch.full((batch_size, x_points), 1.0 / x_points,
                        dtype=torch.float, device=x.device).squeeze()
        nu = torch.full((batch_size, y_points), 1.0 / y_points,
                        dtype=torch.float, device=y.device).squeeze()

        u = torch.zeros_like(mu)
        v = torch.zeros_like(nu)
        # To check if algorithm terminates because of threshold
        # or max iterations reached
        actual_nits = 0
        # Stopping criterion
        thresh = 1e-1

        # Sinkhorn iterations
        for i in range(self.max_iter):
            u1 = u  # useful to check the update
            u = self.eps * (torch.log(mu+1e-8) - torch.logsumexp(self.M(C, u, v), dim=-1)) + u
            v = self.eps * (torch.log(nu+1e-8) - torch.logsumexp(self.M(C, u, v).transpose(-2, -1), dim=-1)) + v
            err = (u - u1).abs().sum(-1).mean()

            actual_nits += 1
            if err.item() < thresh:
                break

        U, V = u, v
        # Transport plan pi = diag(a)*K*diag(b)
        pi = torch.exp(self.M(C, U, V))
        # Sinkhorn distance
        cost = torch.sum(pi * C, dim=(-2, -1))

        if self.reduction == 'mean':
            cost = cost.mean()
        elif self.reduction == 'sum':
            cost = cost.sum()

        return cost, pi, C

    def M(self, C, u, v):
        r"""Modified cost for logarithmic updates: $M_{ij} = (-c_{ij} + u_i + v_j) / \epsilon$"""
        return (-C + u.unsqueeze(-1) + v.unsqueeze(-2)) / self.eps

    @staticmethod
    def _cost_matrix(x, y, p=2):
        "Returns the matrix of $|x_i-y_j|^p$."
        x_col = x.unsqueeze(-2)
        y_lin = y.unsqueeze(-3)
        C = torch.sum((torch.abs(x_col - y_lin)) ** p, -1)
        return C

    @staticmethod
    def ave(u, u1, tau):
        "Barycenter subroutine, used by kinetic acceleration through extrapolation."
        return tau * u + (1 - tau) * u1
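As a self-contained sanity check, the same log-domain Sinkhorn updates can be mirrored in plain NumPy (an assumed re-implementation for illustration, not the author's code). For two single-point clouds the transport plan is forced to 1, so the returned cost must equal the raw squared distance:

```python
import numpy as np

def logsumexp(a, axis):
    # Numerically stable log-sum-exp along one axis.
    mx = np.max(a, axis=axis, keepdims=True)
    return np.squeeze(mx, axis) + np.log(np.sum(np.exp(a - mx), axis=axis))

def sinkhorn(x, y, eps=0.1, max_iter=100, thresh=1e-9):
    # Squared-Euclidean cost matrix C_ij = |x_i - y_j|^2.
    C = np.sum(np.abs(x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
    mu = np.full(x.shape[0], 1.0 / x.shape[0])   # uniform marginals
    nu = np.full(y.shape[0], 1.0 / y.shape[0])
    u, v = np.zeros_like(mu), np.zeros_like(nu)
    M = lambda u, v: (-C + u[:, None] + v[None, :]) / eps
    for _ in range(max_iter):
        u1 = u
        u = eps * (np.log(mu) - logsumexp(M(u, v), axis=1)) + u
        v = eps * (np.log(nu) - logsumexp(M(u, v).T, axis=1)) + v
        if np.abs(u - u1).sum() < thresh:
            break
    pi = np.exp(M(u, v))          # transport plan
    return np.sum(pi * C), pi

cost, pi = sinkhorn(np.array([[0.0]]), np.array([[3.0]]))  # cost = 9.0 exactly
```

The same updates as the class above, just without batching, GPU tensors, or autograd.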
3. WGAN training in PyTorch (outline)

Implementing a Wasserstein GAN (WGAN) in PyTorch can be broken into the following steps:

1. Import the required libraries and modules: PyTorch, torchvision, torch.nn, torch.optim, and numpy.
2. Define the generator and discriminator networks. The generator is typically a stack of transposed-convolution layers that maps a random noise vector to a synthetic image. The discriminator is typically a stack of convolution layers that classifies an input image as real (from the training set) or fake (from the generator).
3. Define the loss function and optimizers. WGAN uses the Wasserstein distance as the discriminator's objective, so this step defines and implements the Wasserstein loss. Adam or RMSprop can be used as the optimizer. (Note that the original WGAN paper additionally clips the discriminator's weights to enforce the Lipschitz constraint.)
4. Define the training loop. In each step, sample a batch of real images, generate a batch of fake images with the generator, score both batches with the discriminator, and compute the discriminator and generator losses. Then update the discriminator's and generator's parameters via backpropagation and the optimizers. Finally, print the losses and save the generator's output images.
5. Train the model: iterate over the prepared dataset for multiple epochs, continually optimizing the generator's and discriminator's parameters.

A skeleton of this PyTorch implementation looks like:

```python
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms

# Generator network
class Generator(nn.Module):
    def __init__(self, ...):
        ...
    def forward(self, ...):
        ...

# Discriminator network
class Discriminator(nn.Module):
    def __init__(self, ...):
        ...
    def forward(self, ...):
        ...

# Wasserstein loss
def wasserstein_loss(...):
    ...

# Optimizers for the generator and the discriminator
generator_optimizer = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
discriminator_optimizer = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

# Training loop
for epoch in range(num_epochs):
    for real_images, _ in data_loader:
        ...
        fake_images = generator(noise)
        real_output = discriminator(real_images)
        fake_output = discriminator(fake_images)
        discriminator_loss = wasserstein_loss(real_output, fake_output)
        generator_loss = -fake_output.mean()

        discriminator_optimizer.zero_grad()
        discriminator_loss.backward(retain_graph=True)
        discriminator_optimizer.step()

        generator_optimizer.zero_grad()
        generator_loss.backward()
        generator_optimizer.step()
        ...
    print('Epoch [{}/{}], Discriminator Loss: {:.4f}, Generator Loss: {:.4f}'
          .format(epoch + 1, num_epochs, discriminator_loss.item(), generator_loss.item()))
    # Save the generator's output images
    torchvision.utils.save_image(fake_images, 'generated_images_epoch{}.png'.format(epoch + 1))
```

This is a simple WGAN skeleton; adjust the network architectures, loss function, and optimizers to your specific needs.
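The `wasserstein_loss` above is left elided; a minimal sketch of the standard WGAN objectives and weight clipping (my assumption of what that helper would contain, following the original WGAN paper, not the post's own code) could look like:

```python
import torch

def critic_loss(real_output, fake_output):
    # The critic maximizes E[D(real)] - E[D(fake)], so its loss is the negation.
    return fake_output.mean() - real_output.mean()

def generator_loss(fake_output):
    # The generator tries to raise the critic's score on fake samples.
    return -fake_output.mean()

def clip_weights(critic, c=0.01):
    # Enforce the Lipschitz constraint by clipping every parameter to [-c, c],
    # as in the original WGAN paper (to be called after each critic update).
    with torch.no_grad():
        for p in critic.parameters():
            p.clamp_(-c, c)
```

Later work (WGAN-GP) replaces weight clipping with a gradient penalty, which tends to train more stably.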
