### WGAN-GP 的 PyTorch 实现
以下是基于 PyTorch 框架实现 Wasserstein GAN with Gradient Penalty (WGAN-GP) 的简化 Python 代码示例:
```python
import torch
from torch import nn, optim
from torchvision.utils import save_image
class Generator(nn.Module):
    """DCGAN-style generator: maps a latent vector of shape (nz, 1, 1)
    to an image of shape (nc, 64, 64) with values in [-1, 1] (Tanh)."""

    def __init__(self, nz=100, ngf=64, nc=1):
        super(Generator, self).__init__()
        # Channel progression of the four BatchNorm+ReLU deconv stages.
        widths = [nz, ngf * 8, ngf * 4, ngf * 2, ngf]
        layers = []
        for stage, (c_in, c_out) in enumerate(zip(widths, widths[1:])):
            # Stage 0 projects 1x1 -> 4x4 (stride 1, padding 0);
            # every later stage doubles the spatial size (stride 2, padding 1).
            stride, padding = (1, 0) if stage == 0 else (2, 1)
            layers.append(nn.ConvTranspose2d(c_in, c_out, 4, stride, padding, bias=False))
            layers.append(nn.BatchNorm2d(c_out))
            layers.append(nn.ReLU(True))
        # Final deconv to nc channels, squashed to [-1, 1].
        layers.append(nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False))
        layers.append(nn.Tanh())
        self.main = nn.Sequential(*layers)

    def forward(self, input_):
        return self.main(input_)
class Discriminator(nn.Module):
    """WGAN-GP critic: maps an (nc, 64, 64) image to one unbounded scalar
    score per sample (no sigmoid — Wasserstein losses use raw scores).

    FIX: the original used BatchNorm2d, which the WGAN-GP paper explicitly
    forbids in the critic — the gradient penalty is defined per individual
    sample, and batch norm makes each output depend on the whole batch.
    InstanceNorm2d (affine, so the learnable scale/shift is kept) normalizes
    each sample independently and is the standard drop-in replacement.
    """

    def __init__(self, ndf=64, nc=1):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            # Input: nc x 64 x 64
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.InstanceNorm2d(ndf * 2, affine=True),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.InstanceNorm2d(ndf * 4, affine=True),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.InstanceNorm2d(ndf * 8, affine=True),
            nn.LeakyReLU(0.2, inplace=True),
            # 4x4 -> 1x1: one score per sample
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False)
        )

    def forward(self, input_):
        # Flatten (B, 1, 1, 1) to (B,)
        return self.main(input_).view(-1)
def compute_gradient_penalty(disc, real_samples, fake_samples, device='cpu'):
    """Compute the WGAN-GP gradient penalty.

    Interpolates randomly between real and fake samples, runs the critic on
    the interpolates, and penalizes the deviation of the per-sample gradient
    norm from 1 (the 1-Lipschitz constraint).

    Args:
        disc: critic callable mapping a (B, C, H, W) tensor to (B,) scores.
        real_samples: real batch, shape (B, C, H, W).
        fake_samples: generated batch, same shape as real_samples.
        device: device for the interpolation coefficients.

    Returns:
        Scalar tensor: mean of (||grad||_2 - 1)^2 over the batch.
    """
    batch_size = real_samples.size(0)
    # One interpolation coefficient per sample, broadcast over C, H, W.
    alpha = torch.rand(batch_size, 1, 1, 1, device=device)
    interpolates = (alpha * real_samples + (1 - alpha) * fake_samples).requires_grad_(True)
    disc_interpolates = disc(interpolates)
    # create_graph=True so the penalty itself is differentiable (it joins
    # the critic loss and must contribute second-order gradients).
    gradients = torch.autograd.grad(outputs=disc_interpolates,
                                    inputs=interpolates,
                                    grad_outputs=torch.ones_like(disc_interpolates),
                                    create_graph=True)[0]
    # BUG FIX: the original took norm(2, dim=1) of the 4-D gradient tensor,
    # i.e. the norm over channels only. The penalty is defined on the norm of
    # the ENTIRE per-sample gradient, so flatten to (B, -1) first.
    gradients = gradients.view(batch_size, -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return gradient_penalty
# --- Training setup ---
# FIX: the original hard-coded 'cuda' (crashes on CPU-only machines) and
# used num_epochs / batch_size / lambda_gp / n_critic without defining them.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
nz = 100          # latent vector size (must match Generator's nz)
lambda_gp = 10    # gradient-penalty weight (WGAN-GP paper default)
n_critic = 5      # critic updates per generator update (paper default)
num_epochs = 25
# NOTE: `dataloader` must be defined by the caller (e.g. a torchvision
# MNIST DataLoader yielding (images, labels) batches of 1 x 64 x 64 images).

netG = Generator().to(device)
netD = Discriminator().to(device)
optimizer_G = optim.Adam(netG.parameters(), lr=0.0002, betas=(0.5, 0.9))
optimizer_D = optim.Adam(netD.parameters(), lr=0.0002, betas=(0.5, 0.9))

for epoch in range(num_epochs):
    for i, data in enumerate(dataloader):
        real_images = data[0].to(device)
        # FIX: size the noise by the actual batch (the last batch of an
        # epoch is usually smaller than the nominal batch_size).
        b_size = real_images.size(0)

        # --- Critic (discriminator) update ---
        netD.zero_grad()
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        generated_images = netG(noise)
        # Wasserstein loss: maximize D(real) - D(fake)  ==  minimize the sum below.
        loss_real = -torch.mean(netD(real_images))
        loss_fake = torch.mean(netD(generated_images.detach()))
        gp = compute_gradient_penalty(netD, real_images.detach(),
                                      generated_images.detach(), device)
        d_loss = loss_real + loss_fake + lambda_gp * gp
        # FIX: no retain_graph=True — the generator step below builds its
        # own fresh graph instead of reusing the critic step's.
        d_loss.backward()
        optimizer_D.step()

        # --- Generator update: once every n_critic critic steps ---
        if i % n_critic == 0:
            netG.zero_grad()
            # Fresh forward pass so gradients flow through the just-updated critic.
            g_loss = -torch.mean(netD(netG(noise)))
            g_loss.backward()
            optimizer_G.step()

            print(f"[Epoch {epoch}/{num_epochs}] "
                  f"[Batch {i}/{len(dataloader)}] "
                  f"[D loss: {d_loss.item()}] "
                  f"[G loss: {g_loss.item()}]")
```
此代码展示了如何构建并训练一个简单的 WGAN-GP 模型。该模型包括定义生成器 `Generator` 和判别器 `Discriminator` 类,以及用于计算梯度惩罚的辅助函数 `compute_gradient_penalty`。