wgan-gp 代码使用

本文分享了在Python3环境下,使用TensorFlow实现DCGAN、LSAGN、WGAN及WGAN-GP的代码实践,详细记录了解决版本兼容问题、数据集加载困难及模型保存失败等挑战的过程。

在查看并对比一些代码之后,我找到了下面这个代码,其中包括了DCGAN,LSAGN,WGAN,以及WGAN-GP的代码

地址:https://github.com/LynnHo/DCGAN-LSGAN-WGAN-WGAN-GP-Tensorflow

其中代码的环境要求是

但是由于我使用的python3,所以由于版本问题遇到了以下问题,

1、Argument must be a dense tensor: range(1, 4) - got shape [3], but wanted [].

原因是python2转python3后,range()返回的shape是range(0,3)而不再是list。

图片来自参考网站,https://blog.csdn.net/ygfrancois/article/details/80688265

实际上我在使用代码  train_mnist_wgan_gp.py的时候修改的是这一行

 

2、数据集

readme中提到mnist的数据集在运行代码的时候会自动下载的,我在使用的时候的确自己下载了,但是不知道为什么下载的数据集没有办法正确加载,所以我使用另一个下载的数据集取代

地址:https://github.com/bojone/gan/tree/master/MNIST_data

这个地址也是一个wgan-gp的代码,但是这个代码在使用的时候并不能自行保存模型

另外两个数据集中,卡通数据集比较小,容易下载。

我不能准确下载人脸数据集。

 

3、卡通数据集使用

首先,将卡通数据集在下载mnist的时候产生的data文件夹中解压,然后运行 train_cartoon_wgan_gp.py就可以了

在上面的那行代码中:

31行可以修改图片的路径以及文件夹名称,

32行可以修改图片的形状,当然数据集当中的所有图片需要统一大小。

              

 

### WGAN-GP 的 PyTorch 实现 以下是基于 PyTorch 框架实现 Wasserstein GAN with Gradient Penalty (WGAN-GP) 的简化 Python 代码示例: ```python import torch from torch import nn, optim from torchvision.utils import save_image class Generator(nn.Module): def __init__(self, nz=100, ngf=64, nc=1): super(Generator, self).__init__() self.main = nn.Sequential( # 输入是 Z 向量 nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False), nn.BatchNorm2d(ngf * 8), nn.ReLU(True), nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False), nn.BatchNorm2d(ngf * 4), nn.ReLU(True), nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False), nn.BatchNorm2d(ngf * 2), nn.ReLU(True), nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False), nn.BatchNorm2d(ngf), nn.ReLU(True), nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False), nn.Tanh() ) def forward(self, input_): return self.main(input_) class Discriminator(nn.Module): def __init__(self, ndf=64, nc=1): super(Discriminator, self).__init__() self.main = nn.Sequential( # 输入大小为 nc x 64 x 64 nn.Conv2d(nc, ndf, 4, 2, 1, bias=False), nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False), nn.BatchNorm2d(ndf * 2), nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False), nn.BatchNorm2d(ndf * 4), nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False), nn.BatchNorm2d(ndf * 8), nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False) ) def forward(self, input_): return self.main(input_).view(-1) def compute_gradient_penalty(disc, real_samples, fake_samples, device='cpu'): """计算梯度惩罚""" alpha = torch.rand(real_samples.size(0), 1, 1, 1).to(device) interpolates = (alpha * real_samples + ((1 - alpha) * fake_samples)).requires_grad_(True) disc_interpolates = disc(interpolates) gradients = torch.autograd.grad(outputs=disc_interpolates, inputs=interpolates, grad_outputs=torch.ones_like(disc_interpolates), create_graph=True)[0] gradient_penalty = ((gradients.norm(2, dim=1) - 1)**2).mean() return gradient_penalty # 初始化生成器和判别器 netG = Generator().to('cuda') netD = Discriminator().to('cuda') optimizer_G = optim.Adam(netG.parameters(), lr=0.0002, betas=(0.5, 0.9)) optimizer_D = optim.Adam(netD.parameters(), lr=0.0002, betas=(0.5, 0.9)) for epoch in range(num_epochs): for i, data in enumerate(dataloader): netD.zero_grad() real_images = data[0].to('cuda') noise = torch.randn(batch_size, 100, 1, 1, device='cuda') generated_images = netG(noise) loss_real = -torch.mean(netD(real_images)) loss_fake = torch.mean(netD(generated_images.detach())) gp = compute_gradient_penalty(netD, real_images.data, generated_images.data, 'cuda') d_loss = loss_real + loss_fake + lambda_gp * gp d_loss.backward(retain_graph=True) optimizer_D.step() if i % n_critic == 0: netG.zero_grad() g_loss = -torch.mean(netD(generated_images)) g_loss.backward() optimizer_G.step() print(f"[Epoch {epoch}/{num_epochs}] " f"[Batch {i}/{len(dataloader)}] " f"[D loss: {d_loss.item()}] " f"[G loss: {g_loss.item()}]") ``` 此代码展示了如何构建并训练一个简单的 WGAN-GP 模型。该模型包括定义生成器 `Generator` 和判别器 `Discriminator` 类,以及用于计算梯度惩罚的辅助函数 `compute_gradient_penalty`。
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值