Wasserstein GAN

Overview

Wasserstein GAN (WGAN) introduces the Wasserstein loss to address several problems in traditional GAN training, such as training instability and mode collapse.

Wasserstein Distance

The Wasserstein distance (in its dual form) is defined as:
$$W(P_r, P_g) = \sup_{\|f\|_L \le 1} \mathbb{E}_{x \sim P_r}[f(x)] - \mathbb{E}_{\tilde{x} \sim P_g}[f(\tilde{x})]$$

The discriminator $D(x)$ plays the role of $f(x)$. The formula has two parts: the expected score gap between the two distributions should be made as large as possible, and the discriminator is subject to a constraint. $\|f\|_L \le 1$ means that $f$ must be a 1-Lipschitz function (defined below).

A function $f$ is 1-Lipschitz if $|f(x_1) - f(x_2)| \le \|x_1 - x_2\|$ for all $x_1, x_2$.
Here $\sup$ denotes the supremum, taken over all 1-Lipschitz functions $f$. In the original WGAN this constraint is enforced by weight clipping; in the variant used below (WGAN-GP) it is enforced by a gradient penalty.
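As a quick illustrative example of the definition above (not from the original post): for two point masses $P_r = \delta_a$ and $P_g = \delta_b$ on the real line, any 1-Lipschitz $f$ satisfies $f(a) - f(b) \le |a - b|$, and $f(x) = \operatorname{sign}(a - b)\,x$ attains this bound, so

$$W(\delta_a, \delta_b) = \sup_{\|f\|_L \le 1} f(a) - f(b) = |a - b|,$$

which matches the intuition that the Wasserstein distance measures how far probability mass must be moved.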

WGAN Discriminator Loss

As the previous section shows, the discriminator's job is to accurately distinguish the two distributions (here, the source and target domains), which amounts to maximizing the Wasserstein distance between them:
$$\max_{D} \; \mathbb{E}_{x \sim P_r}[f(x)] - \mathbb{E}_{\tilde{x} \sim P_g}[f(\tilde{x})]$$
where $P_r$ is the real data distribution and $P_g$ is the generator's distribution. During training, this maximization objective is turned into minimizing its negative:
$$L_D = -\mathbb{E}_{x \sim P_r}[f(x)] + \mathbb{E}_{\tilde{x} \sim P_g}[f(\tilde{x})]$$
Real samples should receive scores that are as high as possible and generated samples scores that are as low as possible, maximizing the gap between the two. Adding the gradient penalty gives:
$$L_D = -\mathbb{E}_{x \sim P_r}[f(x)] + \mathbb{E}_{\tilde{x} \sim P_g}[f(\tilde{x})] + \lambda \cdot \text{gradient penalty}$$
where the gradient penalty is computed as:
$$\text{gradient penalty} = \mathbb{E}_{\hat{x} \sim P_{interp}}\big[(\|\nabla_{\hat{x}} D(\hat{x})\|_2 - 1)^2\big]$$
$P_{interp}$ is the distribution obtained by randomly interpolating between real and generated samples.

Code for the gradient penalty term:

```python
import numpy as np
import torch
from torch import autograd


def compute_gradient_penalty(D, real_samples, fake_samples):
    # D is the discriminator (critic); real_samples are real samples and
    # fake_samples are generated (fake) samples.
    # Random interpolation weights, one per sample; the trailing singleton
    # dimensions broadcast over the remaining sample dimensions.
    alpha = torch.cuda.FloatTensor(np.random.random((real_samples.size(0), 1, 1)))
    # Interpolate between real and fake samples; requires_grad_ lets us take
    # gradients with respect to the interpolated points.
    interpolates = (alpha * real_samples + (1 - alpha) * fake_samples).requires_grad_(True)
    d_interpolates = D(interpolates)
    # Unit tensor used as the weighting of the critic outputs in autograd.grad.
    fake = torch.cuda.FloatTensor(np.ones(d_interpolates.shape))
    # outputs: critic scores on the interpolates; inputs: the interpolated samples;
    # grad_outputs: the unit weights defined above.
    gradients = autograd.grad(outputs=d_interpolates,
                              inputs=interpolates,
                              grad_outputs=fake,
                              create_graph=True, retain_graph=True, only_inputs=True)[0]
    # Flatten each sample's gradient into a vector.
    gradients = gradients.view(gradients.size(0), -1)
    # Penalize the squared deviation of the gradient 2-norm from 1, averaged over the batch.
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return gradient_penalty
```
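For context, this penalty term plugs into the full critic loss $L_D$ given above. Here is a minimal sketch, assuming the same `D`, `real_samples`, and `fake_samples` as in the function above; the weight `lambda_gp = 10` is the value commonly used with WGAN-GP and is an assumption, not something stated in the original post:

```python
def critic_loss(D, real_samples, fake_samples, lambda_gp=10):
    # L_D = -E[D(real)] + E[D(fake)] + lambda * gradient_penalty
    gp = compute_gradient_penalty(D, real_samples, fake_samples)
    return -D(real_samples).mean() + D(fake_samples).mean() + lambda_gp * gp
```

Minimizing this loss maximizes the score gap between real and generated samples while keeping the critic's gradient norm close to 1.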


WGAN Generator Loss

The generator's task is to make the Wasserstein distance as small as possible. Since the Wasserstein distance depends on the discriminator's output and the generator only affects the generated samples, the generator loss can be written as:
$$L_G = -\mathbb{E}_{\tilde{x} \sim P_g}[D(\tilde{x})]$$
In other words, the generator tries to raise the scores of its generated samples so that they move closer to the real distribution.
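A minimal sketch of the corresponding generator update follows. The generator `G`, the noise dimensions `batch_size` and `latent_dim`, and the optimizer `optimizer_G` are illustrative assumptions, not names from the original post:

```python
# Hypothetical generator step: minimize L_G = -E[D(G(z))].
z = torch.randn(batch_size, latent_dim, device="cuda")  # assumed noise shape
fake_samples = G(z)
g_loss = -D(fake_samples).mean()

optimizer_G.zero_grad()
g_loss.backward()
optimizer_G.step()
```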

### Wasserstein Loss Function in Machine Learning

The Wasserstein loss function, also known as the Earth Mover's Distance (EMD), measures the distance between two probability distributions. This metric has become particularly popular within Generative Adversarial Networks (GANs). Unlike traditional GAN losses, which rely on the Jensen-Shannon divergence and can therefore suffer from issues such as vanishing gradients during training, the Wasserstein loss provides a meaningful gradient almost everywhere[^2].

In practice, computing the exact Wasserstein distance involves solving a complex linear programming problem, so approximations are usually employed instead. One common approach is to impose a Lipschitz constraint on the critic, enforced either by weight clipping or by a gradient penalty.

When applied to GAN architectures, the critic (discriminator) learns to estimate this distance rather than simply distinguishing real from fake samples. The generator then aims to minimize this estimated distance, improving sample quality over the iterations. These modifications lead not only to more stable training but potentially to higher-quality outputs as well[^1].

For illustrative purposes, here is how one might define a simple version of WGAN with gradient penalty using PyTorch:

```python
import torch
from torch import nn


class Critic(nn.Module):
    def __init__(self, input_dim):
        super(Critic, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 1)
        )

    def forward(self, x):
        return self.model(x)


def compute_gradient_penalty(critic, real_samples, fake_samples):
    """Calculates the gradient penalty for 2-D (batch, features) inputs."""
    alpha = torch.rand(real_samples.size(0), 1)
    interpolates = (alpha * real_samples + (1 - alpha) * fake_samples).requires_grad_(True)
    d_interpolates = critic(interpolates)
    gradients = torch.autograd.grad(outputs=d_interpolates,
                                    inputs=interpolates,
                                    grad_outputs=torch.ones(d_interpolates.size()),
                                    create_graph=True)[0]
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return gradient_penalty
```
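The section above mentions weight clipping as the other way to enforce the Lipschitz constraint. A minimal sketch of that alternative, assuming a `critic` module like the one defined above; the clip value `0.01` is the default from the original WGAN paper and is an assumption here:

```python
# Original WGAN alternative: after each critic update, clamp every critic
# weight into [-clip_value, clip_value] to keep the critic (roughly) Lipschitz.
clip_value = 0.01  # assumed default, as in the original WGAN paper

for p in critic.parameters():
    p.data.clamp_(-clip_value, clip_value)
```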