WGAN-GP 原理及实现(pytorch版)

WGAN-GP 原理及实现

  • 一、WGAN-GP 原理
    • 1.1 WGAN-GP 核心原理
    • 1.2 WGAN-GP 实现步骤
    • 1.3 总结
  • 二、WGAN-GP 实现
    • 2.1 导包
    • 2.2 数据加载和处理
    • 2.3 构建生成器
    • 2.4 构建判别器
    • 2.5 训练和保存模型
    • 2.6 图片转GIF

一、WGAN-GP 原理

Wasserstein GAN with Gradient Penalty (WGAN-GP) 是对原始 WGAN 的改进,通过梯度惩罚(Gradient Penalty)替代权重裁剪(Weight Clipping),解决了 WGAN 训练不稳定、权重裁剪导致梯度消失或爆炸的问题。


1.1 WGAN-GP 核心原理

(1) Wasserstein 距离(Earth-Mover 距离)

  • 原始 GAN 的 JS 散度在分布不重叠时梯度消失,而 WGAN 使用 Wasserstein 距离衡量生成分布 P g P_g Pg 和真实分布 P r P_r Pr 的距离:
    W ( P r , P g ) = inf ⁡ γ ∼ Π ( P r , P g ) E ( x , y ) ∼ γ [ ∥ x − y ∥ ] W(P_r, P_g) = \inf_{\gamma \sim \Pi(P_r, P_g)} \mathbb{E}_{(x,y)\sim \gamma} [\|x-y\|] W(Pr,Pg)=infγΠ(Pr,Pg)E(x,y)γ[xy]
  • 通过 Kantorovich-Rubinstein 对偶形式,转化为:
    W ( P r , P g ) = sup ⁡ ∥ D ∥ L ≤ 1 E x ∼ P r [ D ( x ) ] − E z ∼ P z [ D ( G ( z ) ) ] W(P_r, P_g) = \sup_{\|D\|_L \leq 1} \mathbb{E}_{x \sim P_r}[D(x)] - \mathbb{E}_{z \sim P_z}[D(G(z))] W(Pr,Pg)=supDL1ExPr[D(x)]EzPz[D(G(z))],其中 D D D 是 1-Lipschitz 函数(梯度范数不超过 1)

(2) 梯度惩罚(Gradient Penalty)

  • 原始 WGAN 的问题:通过权重裁剪强制判别器(Critic)满足 Lipschitz 约束,但会导致梯度不稳定或容量下降
  • WGAN-GP 的改进:直接对判别器的梯度施加惩罚项,强制其梯度范数接近 1: λ ⋅ E x ^ ∼ P x ^ \lambda \cdot \mathbb{E}_{\hat{x} \sim P_{\hat{x}}} λEx^Px^ [ ( ∥ ∇ x ^ D ( x ^ ) ∥ 2 − 1 ) 2 ] \left [(\|\nabla_{\hat{x}} D(\hat{x})\|_2 - 1)^2 \right] [(x^D(x^)21)2]
    • x ^ \hat{x} x^ 是真实数据和生成数据的随机插值点: x ^ = ϵ x + ( 1 − ϵ ) G ( z ) \hat{x} = \epsilon x + (1-\epsilon) G(z) x^=ϵx+(1ϵ)G(z) ϵ ∼ U [ 0 , 1 ] \epsilon \sim U[0,1] ϵU[0,1]
    • λ \lambda λ 是惩罚系数(通常设为 10)

1.2 WGAN-GP 实现步骤

(1) 判别器(Critic)的损失函数
判别器的目标是最大化 Wasserstein 距离,同时满足梯度约束:
L D = E x ∼ P r [ D ( x ) ] − E z ∼ P z [ D ( G ( z ) ) ] ⏟ Wasserstein 距离 + λ ⋅ E x ^ ∼ P x ^ [ ( ∥ ∇ x ^ D ( x ^ ) ∥ 2 − 1 ) 2 ] ⏟ 梯度惩罚 L_D = \underbrace{\mathbb{E}_{x \sim P_r}[D(x)] - \mathbb{E}_{z \sim P_z}[D(G(z))]}_{\text{Wasserstein 距离}} + \underbrace{\lambda \cdot \mathbb{E}_{\hat{x} \sim P_{\hat{x}}} \left[ (\|\nabla_{\hat{x}} D(\hat{x})\|_2 - 1)^2 \right]}_{\text{梯度惩罚}} LD=Wasserstein 距离 ExPr[D(x)]EzPz[D(G(z))]+梯度惩罚 λEx^Px^[(x^D(x^)21)2]

(2) 生成器(Generator)的损失函数
生成器的目标是最小化 Wasserstein 距离: L G = − E z ∼ P z [ D ( G ( z ) ) ] L_G = -\mathbb{E}_{z \sim P_z}[D(G(z))] LG=EzPz[D(G(z))]

(3) 训练流程

  1. 输入:真实数据 x x x,噪声 z ∼ N ( 0 , 1 ) z \sim \mathcal{N}(0,1) zN(0,1)
  2. 生成数据 G ( z ) G(z) G(z)
  3. 插值采样 x ^ = ϵ x + ( 1 − ϵ ) G ( z ) \hat{x} = \epsilon x + (1-\epsilon) G(z) x^=ϵx+(1ϵ)G(z) ϵ ∼ U [ 0 , 1 ] \epsilon \sim U[0,1] ϵU[0,1]
  4. 计算梯度惩罚
    • 对插值样本 x ^ \hat{x} x^ 计算判别器输出 D ( x ^ ) D(\hat{x}) D(x^)
    • 求梯度 ∇ x ^ D ( x ^ ) \nabla_{\hat{x}} D(\hat{x}) x^D(x^) 并计算惩罚项
  5. 更新判别器:最小化 L D L_D LD
  6. 更新生成器:最小化 L G L_G LG(每 n critic n_{\text{critic}} ncritic 次判别器更新后更新 1 次生成器)

1.3 总结

WGAN-GP 通过梯度惩罚替代权重裁剪,显著提升了 WGAN 的训练稳定性,是生成对抗网络的重要改进之一。实际应用中需注意:

  • 判别器架构设计
  • 梯度惩罚的正确实现
  • 学习率和训练次数的调优

二、WGAN-GP 实现

2.1 导包

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms
from torchvision.utils import save_image
import numpy as np

import os
import time
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm  
from torchsummary import summary

# 判断是否存在可用的GPU
device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# 指定存放日志路径
writer=SummaryWriter(log_dir="./runs/wgan_gp")

os.makedirs("./img/wgan_gp_mnist", exist_ok=True) # 存放生成样本目录
os.makedirs("./model", exist_ok=True) # 模型存放目录

2.2 数据加载和处理

# 加载 MNIST 数据集
def load_data(batch_size=64,img_shape=(1,28,28)):
    transform = transforms.Compose([
        transforms.ToTensor(),  # 将图像转换为张量
        transforms.Normalize(mean=[0.5], std=[0.5])  # 归一化到[-1,1]
    ])
    
    # 下载训练集和测试集
    train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
    
    # 创建 DataLoader
    train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, num_workers=2,shuffle=True)
    test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, num_workers=2,shuffle=False)
    return train_loader, test_loader

2.3 构建生成器

class Generator(nn.Module):
    """生成器"""
    def __init__(self, latent_dim=100,img_shape=(1,28,28)):
        super(Generator,self).__init__()

        # 网络块
        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat))
            layers.append(nn.LeakyReLU(negative_slope=0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh() # 输出归一化到[-1,1] 
        )

        
    def forward(self,z): # 噪声z,2维[batch_size,latent_dim]
        gen_img=self.model(z) 
        gen_img=gen_img.view(gen_img.shape[0],*img_shape)
        return gen_img # 4维[batch_size,1,H,W]

2.4 构建判别器

class Discriminator(nn.Module):
    """判别器"""
    def __init__(self,img_shape=(1,28,28)):
        super(Discriminator, self).__init__()

        self.model = nn.Sequential(
                nn.Linear(int(np.prod(img_shape)), 512),
                nn.LeakyReLU(negative_slope=0.2, inplace=True),
                nn.Linear(512, 256),
                nn.LeakyReLU(negative_slope=0.2, inplace=True),
                nn.Linear(256, 1)
            )

    def forward(self,img): # 输入图片,4维[batc_size,1,H,W]
        img=img.view(img.shape[0], -1) 
        pred = self.model(img)
        return pred # 2维[batch_size,1] 

2.5 训练和保存模型

  • WGAN-GP 算法流程

  • 定义梯度惩罚函数

def compute_gradient_penalty(critic, real, fake, device):
    batch_size = real.shape[0]
    epsilon = torch.rand(batch_size, 1, 1, 1).to(device)  # 随机插值系数
    interpolates = (epsilon * real + (1 - epsilon) * fake).requires_grad_(True)
    critic_interpolates = critic(interpolates)
    
    # 计算梯度
    gradients = torch.autograd.grad(
        outputs=critic_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(critic_interpolates),
        create_graph=True,
        retain_graph=True,
    )[0]
    
    gradients = gradients.view(gradients.shape[0], -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return gradient_penalty
  • 训练和保存
# 设置超参数
batch_size = 64
epochs = 200
lr= 0.0002
latent_dim=100 # 生成器输入噪声向量的长度(维数)
sample_interval=400 #每400次迭代保存生成样本

# WGAN的特别设置
num_iter_critic = 5
lambda_gp = 10

# 设置图片形状1*28*28
img_shape = (1,28,28)

# 加载数据
train_loader,_= load_data(batch_size=batch_size,img_shape=img_shape)

# 实例化生成器G、判别器D
G=Generator().to(device)
D=Discriminator().to(device)

# 设置优化器
optimizer_G = torch.optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))

# 开始训练
batches_done=0
loader_len=len(train_loader) #训练集加载器的长度
for epoch in range(epochs):
    # 进入训练模式
    G.train()
    D.train()
    
    loop = tqdm(train_loader, desc=f"第{epoch+1}轮")
    for i, (real_imgs, _) in enumerate(loop):
        real_imgs=real_imgs.to(device)  # [B,C,H,W]

        
        # -----------------
        #  训练判别器
        # -----------------
        
        # 获取噪声样本[B,latent_dim)
        z=torch.normal(0,1,size=(real_imgs.shape[0],latent_dim),device=device)  #从正态分布中抽样
 
        # Step-1 计算判断器损失=判断真实图片损失+判断生成图片损失+惩罚项
        fake_imgs=G(z).detach()
        gradient_penalty=compute_gradient_penalty(D, real_imgs, fake_imgs, device)
        dis_loss=-torch.mean(D(real_imgs)) + torch.mean(D(fake_imgs))+lambda_gp*gradient_penalty
       
        # Step-2 更新判别器参数
        optimizer_D.zero_grad() # 梯度清零
        dis_loss.backward() #反向传播,计算梯度
        optimizer_D.step()  #更新判别器 

        # -----------------
        #  训练生成器
        # -----------------
 
        # 判别器每迭代 num_iter_critic 次,生成器迭代一次
        if i % num_iter_critic ==0 :

            gen_imgs=G(z).detach()

            # 更新生成器参数
            optimizer_G.zero_grad() #梯度清零
            gen_loss=-torch.mean(D(gen_imgs))
            gen_loss.backward() #反向传播,计算梯度
            optimizer_G.step()  #更新生成器  

             # 更新进度条
            loop.set_postfix(
                gen_loss=f"{gen_loss:.8f}",
                dis_loss=f"{dis_loss:.8f}"
            )
            

        # 每 sample_interval 次迭代保存生成样本
        if batches_done % sample_interval == 0:
            save_image(gen_imgs.data[:25], f"./img/wgan_gp_mnist/{epoch}_{i}.png", nrow=5, normalize=True)
        batches_done += 1

print('总共训练用时: %.2f min' % ((time.time() - start_time)/60))

#仅保存模型的参数(权重和偏置),灵活性高,可以在不同的模型结构之间加载参数
torch.save(G.state_dict(), "./model/WGAN-GP_G.pth") 
torch.save(D.state_dict(), "./model/WGAN-GP_D.pth") 

2.6 图片转GIF

from PIL import Image

def create_gif(img_dir="./img/wgan_gp_mnist", output_file="./img/wgan_gp_mnist/wgan_gp_figure.gif", duration=100):
    images = []
    img_paths = [f for f in os.listdir(img_dir) if f.endswith(".png")]
    
    # 自定义排序:按 "x_y.png" 的 x 和 y 排序
    img_paths_sorted = sorted(
        img_paths,
        key=lambda x: (
            int(x.split('_')[0]),  # 第一个数字(如 0_400.png 的 0)
            int(x.split('_')[1].split('.')[0])  # 第二个数字(如 0_400.png 的 400)
        )
    )
    
    for img_file in img_paths_sorted:
        img = Image.open(os.path.join(img_dir, img_file))
        images.append(img)
    
    images[0].save(output_file, save_all=True, append_images=images[1:], 
                  duration=duration, loop=0)
    print(f"GIF已保存至 {output_file}")
create_gif()
### WGAN 基本原理 Wasserstein GAN (WGAN) 是一种改进的生成对抗网络(GAN),旨在解决原始 GAN 训练过程中遇到的一些问题,比如模式崩溃和训练不稳定等问题[^1]。传统的 GAN 使用的是 Jensen-Shannon 散度来衡量真实数据分布 \( P_r \) 和生成的数据分布 \( P_g \),这可能导致梯度消失现象,在某些情况下使得模型难以收敛。 为了克服这些问题,WGAN 引入了 Wasserstein 距离作为损失函数的基础。该距离也被称为 Earth Mover's Distance (EMD),它提供了更平滑的距离度量方式,并且对于概率分布之间的差异更加敏感。具体来说,给定两个分布 \( P_r \) 和 \( P_g \),它们之间 Wasserstein 距离定义为: \[ W(P_r, P_g)=\inf _{\gamma \in \Pi\left(P_{r}, P_{g}\right)} \mathbb{E}_{(x, y) \sim \gamma}[d(x, y)] \] 其中 \( d(\cdot,\cdot) \) 表示样本间的某种成本函数,通常取欧几里得范数;\( \Pi(P_r,P_g) \) 则表示所有联合分布在边缘上分别等于 \( P_r \) 和 \( P_g \) 的集合[^3]。 然而直接计算上述表达式非常困难,因此通过 Kantorovich-Rubinstein 对偶理论可以将其转换成更容易处理的形式: \[ W(P_r, P_g)=\sup _{|f|_L \leq 1} \mathbb{E}_{x \sim P_r}[f(x)]-\mathbb{E}_{y \sim P_g}[f(y)] \] 这里 \( |f|_L \leq 1 \) 意味着 Lipschitz 连续条件下的最大斜率为 1 。这个新的形式允许我们利用神经网络去近似最优传输映射 f ,从而简化优化过程[^4]。 ### PyTorch 实现教程 下面是一个简单的例子展示如何使用 PyTorch实现 WGAN : ```python import torch from torch import nn, optim class Generator(nn.Module): def __init__(self, input_dim=100, output_dim=784): # MNIST 图像大小为28*28=784像素点 super().__init__() self.model = nn.Sequential( *[ nn.Linear(input_dim, 256), nn.LeakyReLU(), nn.BatchNorm1d(256), nn.Linear(256, 512), nn.LeakyReLU(), nn.BatchNorm1d(512), nn.Linear(512, output_dim), nn.Tanh() # 输出范围[-1,+1], 需要预处理输入图像至相同区间 ] ) def forward(self, z): return self.model(z) class Critic(nn.Module): # 注意:在WGAN中称为Critic而不是Discriminator def __init__(self, img_shape=(1, 28, 28)): super().__init__() dim = int(torch.prod(torch.tensor(img_shape))) # 将多维张量展平后的维度 self.model = nn.Sequential( *[nn.Linear(dim, 512), nn.LeakyReLU()], *[nn.Linear(512, 256), nn.LeakyReLU()], nn.Linear(256, 1) ) def forward(self, imgs): imgs_flat = imgs.view(imgs.size(0), -1) validity = self.model(imgs_flat) return validity def compute_gradient_penalty(critic, real_samples, fake_samples): """Calculates the gradient penalty loss for WGAN GP""" Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor alpha = Tensor(np.random.random((real_samples.size(0), 1))) interpolates = (alpha * real_samples + ((1 - alpha) * fake_samples)).requires_grad_(True) critic_interpolates = critic(interpolates) gradients = autograd.grad(outputs=critic_interpolates, inputs=interpolates, grad_outputs=torch.ones_like(critic_interpolates).to(device), create_graph=True)[0] gradients = gradients.view(gradients.size(0), -1) gradient_penalty = ((gradients.norm(2, dim=1) - 1)**2).mean() return gradient_penalty # 初始化参数... latent_dim = 100 n_critic = 5 # 批次更新次数比例(C:D) clip_value = 0.01 # 参数裁剪阈值用于控制权重范围 [-c,c] learning_rate = 0.00005 batches_done = 0 # 绘图计数器初始化... generator = Generator(latent_dim, np.prod(img_shape)) critic = Critic() if cuda: generator.cuda(), critic.cuda() optimizer_G = optim.RMSprop(generator.parameters(), lr=learning_rate) optimizer_D = optim.RMSprop(critic.parameters(), lr=learning_rate) for epoch in range(n_epochs): for i, (imgs, _) in enumerate(dataloader): # Configure input real_imgs = Variable(imgs.type(Tensor)) # --------------------- # Train Discriminator/Critic # --------------------- optimizer_D.zero_grad() # Sample noise as generator input z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], latent_dim)))) # Generate a batch of images fake_imgs = generator(z).detach() # Real images real_validity = critic(real_imgs) # Fake images fake_validity = critic(fake_imgs) # Gradient penalty gradient_penalty = compute_gradient_penalty(critic, real_imgs.data, fake_imgs.data) # Adversarial loss with gradient penalty term added to objective function. d_loss = - torch.mean(fake_validity)) \ + lambda_gp * gradient_penalty d_loss.backward(retain_graph=True) optimizer_D.step() # Clip weights of discriminator/critic between (-c, c). for p in critic.parameters(): p.data.clamp_(-clip_value, clip_value) # Only update generator every n_critic iterations if i % n_critic == 0: # --- # Train Generator # ----------------- optimizer_G.zero_grad() # Generate new set of samples since last time D was updated gen_z = Variable(Tensor(np.random.normal(0, 1, (batch_size, latent_dim)))) gen_imgs = generator(gen_z) g_loss = -torch.mean(critic(gen_imgs)) g_loss.backward() optimizer_G.step() batches_done += 1 ```
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

入梦风行

你的鼓励是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值