Build Your Own Generative Adversarial Network (GAN) with PyTorch



Generative Adversarial Networks are the most interesting idea in machine learning in the last ten years.


— Yann Lecun (Facebook AI Director)


So, have you heard about GANs, or have you just started learning about them? GANs were first introduced in 2014 by Ian Goodfellow, then a Ph.D. student at the University of Montreal. The most common application of GANs is image generation. For example, there is a website that shows faces of people who do not exist, all produced by a GAN. That is what we are going to build in this lesson.

A Generative Adversarial Network consists of two neural networks, a generator and a discriminator, that compete with each other. I will explain each step in detail later in this lesson. If you are completely unfamiliar with this topic, I suggest you go through the following lessons first.

By the end of this lesson, you'll be able to build and train your own generative adversarial network from scratch. So without further ado, let's dive in.

I have also built this generative network in Google Colab here. You can open it in Google Colaboratory and follow along.

The Road Ahead

  • Step 0: Import Datasets

  • Step 1: Loading and Preprocessing of Images

  • Step 2: Define the Discriminator Algorithm

  • Step 3: Define the Generator Algorithm

  • Step 4: Write the training Algorithm

  • Step 5: Train the Model

  • Step 6: Test the Model


So, are you excited? Let's dive in!

Step 0: Import Datasets

The first step is to download the data and load it into memory. We'll be using the CelebFaces Attributes Dataset (CelebA) to train our adversarial networks.

  • Download the dataset from here.


  • Unzip the dataset.

  • Clone this GitHub repository.

After doing that, you can either open it in a Colab environment or train the model on your own PC.


Import the necessary libraries

It is good practice to import all the libraries you'll be using in the first block of the notebook.


# import the necessary libraries
import pickle as pkl
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets
from torchvision import transforms

Step 1: Loading and Preprocessing of Images

In this step, we're going to preprocess the image data that we downloaded in the previous section.

The following steps will be taken:


  1. Resize the images

  2. Convert it into tensors

  3. Load it into PyTorch Dataset

  4. Load it into PyTorch DataLoader

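# Define hyperparameters
batch_size = 32
img_size = 32
# Placeholder: set this to the folder where you unzipped the dataset
data_dir = 'path/to/dataset/'

# Apply the transformations: resize, then convert to tensors
# (an int resizes the shorter side, so this assumes square source images)
transform = transforms.Compose([transforms.Resize(img_size),
                                transforms.ToTensor()])
# Load the dataset
imagenet_data = datasets.ImageFolder(data_dir, transform=transform)

# Load the image data into a dataloader
celeba_train_loader = torch.utils.data.DataLoader(imagenet_data,
                                                  batch_size=batch_size,
                                                  shuffle=True)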

The images should be kept small, which helps the model train faster. Tensors are similar to NumPy arrays; here we are simply converting our images into PyTorch tensors, which is the format PyTorch needs to work with them.

Then we load the transformed images into a PyTorch Dataset. Since we'll be training on small batches, we wrap the dataset in a data loader, which provides a batch of image data to our model at every iteration.
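To make the batching behaviour concrete, here is a minimal sketch that uses random tensors as a stand-in for the image folder (the stand-in data and the sample count of 100 are assumptions for illustration, not part of the original notebook). It shows how a DataLoader splits a dataset into batches of 32 and leaves the remainder in the last batch.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in data: 100 random "images" of shape 3 x 32 x 32 (not CelebA).
fake_images = torch.rand(100, 3, 32, 32)
fake_labels = torch.zeros(100, dtype=torch.long)
dataset = TensorDataset(fake_images, fake_labels)

loader = DataLoader(dataset, batch_size=32, shuffle=True)

# Collect the shape of every batch produced in one pass over the data.
batch_shapes = [tuple(images.shape) for images, _ in loader]
print(len(loader))       # number of batches per epoch
print(batch_shapes[0])   # shape of a full batch
print(batch_shapes[-1])  # the last batch holds the remainder
```

With 100 samples and a batch size of 32, one epoch consists of three full batches and one final batch of 4 images.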

Then we check the loaded image data using NumPy and pyplot. A helper function for displaying the images is provided in the notebook. After calling it, you'd get output like this.

Figure: Images of the loaded dataset
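The notebook's display helper is not reproduced in this article. As a rough idea of what such a helper has to do, the sketch below converts one channels-first image into the channels-last layout that pyplot expects. The function name `to_displayable` is hypothetical, and the input is assumed to hold float values in [0, 1].

```python
import numpy as np

def to_displayable(img):
    """Convert a C x H x W float image in [0, 1] to an H x W x C uint8 array."""
    arr = np.asarray(img, dtype=np.float32)
    arr = np.transpose(arr, (1, 2, 0))        # channels-last layout for imshow
    return (np.clip(arr, 0.0, 1.0) * 255).astype(np.uint8)

# Stand-in for one image taken from the loader (random values, not CelebA).
example = np.random.rand(3, 32, 32)
displayable = to_displayable(example)
print(displayable.shape, displayable.dtype)
```

The resulting array can be passed directly to `plt.imshow`.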

Now that the data is loaded, we can preprocess the images.

Preprocessing of images

We will be using a tanh-activated generator in training. The output of this generator is in the range of -1 to 1, so we need to rescale our images into that range too.

def scale(img, feature_range=(-1, 1)):
    '''Scales the input image (assumed to be in [0, 1]) into the given feature_range.'''
    low, high = feature_range
    img = img * (high - low) + low
    return img

This function rescales any input image. We will use it later in training.
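As a quick sanity check of the rescaling arithmetic, the sketch below applies the same formula to a few hand-picked pixel values. It assumes the inputs are already in [0, 1], which is what `transforms.ToTensor()` produces.

```python
import torch

def scale(img, feature_range=(-1, 1)):
    """Map values from [0, 1] into feature_range."""
    low, high = feature_range
    return img * (high - low) + low

pixels = torch.tensor([0.0, 0.25, 0.5, 1.0])
scaled = scale(pixels)
# 0 -> -1, 0.25 -> -0.5, 0.5 -> 0, 1 -> 1
print(scaled)
```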

Now we’re done with the boring preprocessing step.


Here comes the exciting part. Now we need to write code for our generator and discriminator neural networks.


Step 2: Define the Discriminator Algorithm


A discriminator is a neural network that differentiates between real and fake images. Both real images and the images produced by the generator are given to it.

We will first define a helper function that is quite handy in creating the convolutional network layers.


# helper conv function
def conv(in_channels, out_channels, kernel_size, stride=2, padding=1, batch_norm=True):
    layers = []
    conv_layer = nn.Conv2d(in_channels, out_channels,
                           kernel_size, stride, padding, bias=False)
    # Appending the layer
    layers.append(conv_layer)
    # Applying the batch normalization if it's given true
    if batch_norm:
        layers.append(nn.BatchNorm2d(out_channels))
    # Returning the sequential container
    return nn.Sequential(*layers)

This helper function takes the arguments needed to create a convolutional layer and returns a sequential container. Now we will use it to create our own discriminator network.

class Discriminator(nn.Module):

    def __init__(self, conv_dim):
        super(Discriminator, self).__init__()

        self.conv_dim = conv_dim

        #Input 32 x 32 -> 16 x 16
        self.cv1 = conv(3, self.conv_dim, 4, batch_norm=False)
        #16 x 16 -> 8 x 8
        self.cv2 = conv(self.conv_dim, self.conv_dim*2, 4, batch_norm=True)
        #8 x 8 -> 4 x 4
        self.cv3 = conv(self.conv_dim*2, self.conv_dim*4, 4, batch_norm=True)
        #4 x 4 -> 2 x 2
        self.cv4 = conv(self.conv_dim*4, self.conv_dim*8, 4, batch_norm=True)
        #Fully connected layer on the flattened 2 x 2 feature maps
        self.fc1 = nn.Linear(self.conv_dim*8*2*2,1)

    def forward(self, x):
        # After passing through each layer
        # Applying leaky relu activation function
        x = F.leaky_relu(self.cv1(x),0.2)
        x = F.leaky_relu(self.cv2(x),0.2)
        x = F.leaky_relu(self.cv3(x),0.2)
        x = F.leaky_relu(self.cv4(x),0.2)
        # To pass through the fully connected layer
        # We need to flatten the feature maps first
        x = x.view(-1,self.conv_dim*8*2*2)
        # Now passing through fully-connected layer
        x = self.fc1(x)
        return x
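The size of the fully connected layer follows from how each convolution shrinks the image. The sketch below works through that arithmetic with the standard convolution output-size formula, using the kernel size, stride, and padding from the helper above; the function name `conv_out_size` is made up for this illustration.

```python
def conv_out_size(n, kernel_size=4, stride=2, padding=1):
    """Spatial output size of a Conv2d layer for a square input of side n."""
    return (n + 2 * padding - kernel_size) // stride + 1

sizes = [32]
for _ in range(4):          # four conv layers: cv1 .. cv4
    sizes.append(conv_out_size(sizes[-1]))
print(sizes)                # [32, 16, 8, 4, 2]

conv_dim = 32               # same value the article later uses for d_conv_dim
flattened = conv_dim * 8 * sizes[-1] * sizes[-1]
print(flattened)            # 1024 input features for the fully connected layer
```

Each layer halves the spatial size, so a 32 x 32 input ends up as 2 x 2 feature maps, which is why the linear layer takes `conv_dim*8*2*2` inputs.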

Step 3: Define the Generator Algorithm

Figure: Generator diagram (noise z in, generated sample G(z) out)

As you can see from the diagram, we feed a Gaussian noise vector into the network and it outputs a generated sample. The "z" in the figure is the noise, and G(z) on the right is the generated sample.

Same as the discriminator, we will first create a helper function for building our generator network as follows:


def deconv(in_channels, out_channels, kernel_size, stride=2, padding=1, batch_norm=True):
    layers = []
    convt_layer = nn.ConvTranspose2d(in_channels, out_channels,
                           kernel_size, stride, padding, bias=False)
    # Appending the above transposed conv layer
    layers.append(convt_layer)

    if batch_norm:
        # Applying the batch normalization if True
        layers.append(nn.BatchNorm2d(out_channels))
    # Returning the sequential container
    return nn.Sequential(*layers)

Now, it’s time to build the generator network!!


class Generator(nn.Module):
    def __init__(self, z_size, conv_dim):
        super(Generator, self).__init__()

        self.z_size = z_size
        self.conv_dim = conv_dim
        self.fc = nn.Linear(z_size, self.conv_dim*8*2*2)
        self.dcv1 = deconv(self.conv_dim*8, self.conv_dim*4, 4, batch_norm=True)
        self.dcv2 = deconv(self.conv_dim*4, self.conv_dim*2, 4, batch_norm=True)
        self.dcv3 = deconv(self.conv_dim*2, self.conv_dim, 4, batch_norm=True)
        self.dcv4 = deconv(self.conv_dim, 3, 4, batch_norm=False)
        #32 x 32

    def forward(self, x):
        # Passing through fully connected layer
        x = self.fc(x)
        # Changing the dimension
        x = x.view(-1,self.conv_dim*8,2,2)
        # Passing through deconv layers
        # Applying the ReLu activation function
        x = F.relu(self.dcv1(x))
        x = F.relu(self.dcv2(x))
        x = F.relu(self.dcv3(x))
        # tanh keeps the output in the range [-1, 1]
        x = torch.tanh(self.dcv4(x))
        # Returning the generated image
        return x
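The generator mirrors the discriminator: every transposed convolution doubles the spatial size, taking the 2 x 2 feature maps up to 32 x 32. The sketch below checks this for a single layer with the same kernel size, stride, and padding; the channel counts and batch size are arbitrary values chosen for the illustration.

```python
import torch
import torch.nn as nn

# One transposed convolution with the article's settings: kernel 4, stride 2, padding 1.
# Output side = (n - 1) * stride - 2 * padding + kernel_size = 2 * n.
up = nn.ConvTranspose2d(8, 3, kernel_size=4, stride=2, padding=1, bias=False)

x = torch.randn(5, 8, 16, 16)      # batch of 5 feature maps, 16 x 16 each
y = torch.tanh(up(x))              # tanh squashes values into [-1, 1]

print(tuple(y.shape))              # (5, 3, 32, 32)
print(float(y.min()) >= -1.0, float(y.max()) <= 1.0)
```

Because the last layer ends in tanh, the generated images live in the same [-1, 1] range as the rescaled real images.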

To help the model converge faster, we initialize the weights of the linear and convolutional layers as described in the research paper:

All weights were initialized from a zero-centered Normal distribution with standard deviation 0.02.


We will be defining a function for this purpose as follows:


def weights_init_normal(m):
    classname = m.__class__.__name__
    # For the linear layers
    if 'Linear' in classname:
        nn.init.normal_(m.weight, 0.0, 0.02)
    # For the convolutional layers
    if 'Conv' in classname or 'BatchNorm2d' in classname:
        nn.init.normal_(m.weight, 0.0, 0.02)
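To see what this initialization does in practice, the sketch below applies a normal(0, 0.02) initializer to a small stand-in network with `Module.apply` and then inspects the resulting weight statistics. It selects layers with `isinstance` rather than by class name, which is a substitution made for this example, and the network itself is hypothetical.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

def init_normal(m):
    # Zero-centered normal with standard deviation 0.02 for conv and linear weights.
    if isinstance(m, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.normal_(m.weight, mean=0.0, std=0.02)

# Small stand-in network (not the article's Discriminator).
net = nn.Sequential(
    nn.Conv2d(3, 64, 4, 2, 1, bias=False),
    nn.Flatten(),
    nn.Linear(64 * 16 * 16, 1),
)
net.apply(init_normal)   # apply() visits every submodule recursively

conv_std = net[0].weight.std().item()
linear_std = net[2].weight.std().item()
print(round(conv_std, 3), round(linear_std, 3))
```

Both standard deviations should come out close to 0.02, which confirms that `apply` reached every layer.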

Now we initialize the hyperparameters and both networks as follows:


# Defining the model hyperparameters
d_conv_dim = 32
g_conv_dim = 32
z_size = 100   #Size of noise vector

D = Discriminator(d_conv_dim)
G = Generator(z_size=z_size, conv_dim=g_conv_dim)
# Applying the weight initialization
D.apply(weights_init_normal)
G.apply(weights_init_normal)

# Printing both architectures
print(D)
print(G)

The output would be something like this:


Figure: Model architecture

Discriminator Loss:


For the discriminator, the total loss is the sum of the losses for real and fake images, d_loss = d_real_loss + d_fake_loss.


Remember that we want the discriminator to output 1 for real images and 0 for fake images, so we need to set up the losses to reflect that.


Source: DCGAN Research Paper


We will define two loss functions, one for the real loss and one for the fake loss, as follows:

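def real_loss(D_out, smooth=False):
    batch_size = D_out.size(0)
    if smooth:
        # Smoothed labels: 0.9 instead of 1.0
        labels = torch.ones(batch_size)*0.9
    else:
        labels = torch.ones(batch_size)
    # Move the labels to the same device as the discriminator output
    labels = labels.to(D_out.device)
    criterion = nn.BCEWithLogitsLoss()
    loss = criterion(D_out.squeeze(), labels)
    return loss

def fake_loss(D_out):

    batch_size = D_out.size(0)
    labels = torch.zeros(batch_size)
    # Move the labels to the same device as the discriminator output
    labels = labels.to(D_out.device)
    criterion = nn.BCEWithLogitsLoss()
    loss = criterion(D_out.squeeze(), labels)
    return loss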
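Both functions rely on `nn.BCEWithLogitsLoss`, which applies a sigmoid to the raw discriminator output and then computes binary cross-entropy. The sketch below reproduces that computation by hand for a few made-up logits and also shows the effect of the 0.9 label smoothing; the logit values are assumptions chosen for illustration.

```python
import torch
import torch.nn as nn

logits = torch.tensor([2.0, -1.0, 0.5])     # hypothetical raw discriminator outputs
real_labels = torch.ones(3)

criterion = nn.BCEWithLogitsLoss()
library_loss = criterion(logits, real_labels)

# The same quantity by hand: mean of -[y*log(p) + (1-y)*log(1-p)], with p = sigmoid(logit).
p = torch.sigmoid(logits)
manual_loss = -(real_labels * torch.log(p)
                + (1 - real_labels) * torch.log(1 - p)).mean()

# Label smoothing as in real_loss(smooth=True): targets of 0.9 instead of 1.0.
smooth_loss = criterion(logits, real_labels * 0.9)

print(torch.allclose(library_loss, manual_loss))
print(smooth_loss.item() > library_loss.item())
```

For these logits the smoothed loss is slightly higher, because outputs that are already confidently "real" are penalized a little when the target is 0.9.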

Generator Loss:


The generator loss looks similar, only with flipped labels. The generator's goal is to get the discriminator to think its generated images are real.

Source: DCGAN Research Paper
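The label flip can be seen by scoring the same discriminator outputs against both targets. In the sketch below, the logits are made-up values representing a discriminator that is confident the generated images are fake: its own fake loss is small, while the generator's loss on the same outputs is large.

```python
import torch
import torch.nn as nn

criterion = nn.BCEWithLogitsLoss()

# Hypothetical discriminator logits for a batch of generated images.
# Strongly negative logits mean the discriminator is confident they are fake.
d_out_on_fakes = torch.tensor([-3.0, -2.5, -4.0])

d_fake_loss = criterion(d_out_on_fakes, torch.zeros(3))  # discriminator target: 0 (fake)
g_loss = criterion(d_out_on_fakes, torch.ones(3))        # generator target: 1 (real)

print(d_fake_loss.item() < g_loss.item())
```

This opposition is what makes the training adversarial: whatever lowers one of these two losses tends to raise the other.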


Now, it's time to set the optimizers for our networks.

I will be using the Adam optimizer for training, as it is considered a good choice for GANs. You can choose your own by reading this. The values of the hyperparameters are set according to this research paper, whose authors experimented with them and found that these values worked best.

lr = 0.0005
beta1 = 0.3
beta2 = 0.999 # default value
# Optimizers
d_optimizer = optim.Adam(D.parameters(), lr, betas=(beta1, beta2))
g_optimizer = optim.Adam(G.parameters(), lr, betas=(beta1, beta2))

Step 4: Write the training Algorithm

We have to write the training algorithm for our two neural networks. First, we need to initialize a noise vector for sampling, which we keep fixed throughout the training process.

# Initializing arrays to store losses and samples
samples = []
losses = []

# We need to initialize fixed data for sampling
# This helps us evaluate the model's performance over time
sample_size = 16  # number of fixed samples (assumed value; not given in the article)
fixed_z = np.random.uniform(-1, 1, size=(sample_size, z_size))
fixed_z = torch.from_numpy(fixed_z).float()
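Keeping the noise fixed means every saved sample comes from the same inputs, so differences between epochs reflect only changes in the generator. A minimal sketch of how the fixed vector might be used is shown below; the helper name `sample_fixed`, the sample size of 16, and the one-layer stand-in generator are all assumptions made so the example runs on its own.

```python
import numpy as np
import torch
import torch.nn as nn

z_size, sample_size = 100, 16   # sample_size is an assumed value

# Stand-in generator (a single linear layer), only to make the sketch runnable.
G = nn.Sequential(nn.Linear(z_size, 3 * 32 * 32), nn.Tanh())

fixed_z = torch.from_numpy(
    np.random.uniform(-1, 1, size=(sample_size, z_size))).float()

def sample_fixed(generator, z):
    """Generate images from z without tracking gradients."""
    generator.eval()                 # inference behaviour for batch norm layers
    with torch.no_grad():
        out = generator(z)
    generator.train()                # switch back before training resumes
    return out.view(-1, 3, 32, 32)

samples = [sample_fixed(G, fixed_z)]   # e.g. append once per epoch
print(tuple(samples[0].shape))
```

Calling the helper once per epoch and appending the result to `samples` gives a sequence of images that can be compared side by side after training.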

For discriminator:


We first pass the real images through the discriminator network and calculate the real loss on its output. Then we generate fake images and pass them through the discriminator network to calculate the fake loss.

After calculating both real and fake losses we would add them and take the optimizer step for training.


# setting the discriminator's gradients to zero
# to remove residue from the previous training step
d_optimizer.zero_grad()

# rescale real images to [-1, 1] and move them to the
# discriminator's device (GPU memory, if the model lives there)
real_images = scale(real_images).to(next(D.parameters()).device)

# Pass through discriminator network
dreal = D(real_images)

# Calculate the real loss
dreal_loss = real_loss(dreal)

# For fake images

# Generating the fake images
z = np.random.uniform(-1, 1, size=(batch_size, z_size))
z = torch.from_numpy(z).float()

# move z to the generator's device (GPU memory, if the model lives there)
z = z.to(next(G.parameters()).device)

# Generating fake images by passing it to generator
fake_images = G(z)

# Passing fake images through the discriminator network
dfake = D(fake_images)
# Calculating the fake loss
dfake_loss = fake_loss(dfake)

# Adding both losses
d_loss = dreal_loss + dfake_loss
# Taking the backpropagation step
d_loss.backward()
d_optimizer.step()

For Generator:


We do the same to train the generator network. This time, after passing the fake images through the discriminator, we calculate the real loss on its output (that is, with flipped labels) and then optimize the generator network.

## Training the generator for adversarial loss
# setting the generator's gradients to zero
g_optimizer.zero_grad()

# Generate fake images
z = np.random.uniform(-1, 1, size=(batch_size, z_size))
z = torch.from_numpy(z).float()
# moving to the generator's device (GPU memory, if the model lives there)
z = z.to(next(G.parameters()).device)

# Generating fake images
fake_images = G(z)

# Calculating the generator loss on fake images
# Just flipping the labels by reusing our real loss function
D_fake = D(fake_images)
g_loss = real_loss(D_fake, True)

# Taking the backpropagation step
g_loss.backward()
g_optimizer.step()

Step 5: Train the Model

Now we start training for 100 epochs :D
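The article does not show the outer loop, so here is one possible sketch of how the two steps from Step 4 fit together. It is not the notebook's code: the tiny stand-in networks, the random data, and the two-epoch run are assumptions chosen so the example finishes quickly, and the fake images are detached during the discriminator step, which is a common variation rather than something the article does.

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
z_size, batch_size, n_epochs = 100, 32, 2   # the article trains for 100 epochs

# Stand-ins so the sketch runs on its own; swap in the real D, G and CelebA loader.
D = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 1)).to(device)
G = nn.Sequential(nn.Linear(z_size, 3 * 32 * 32), nn.Tanh(),
                  nn.Unflatten(1, (3, 32, 32))).to(device)
data = TensorDataset(torch.rand(64, 3, 32, 32), torch.zeros(64))
loader = DataLoader(data, batch_size=batch_size, shuffle=True)

criterion = nn.BCEWithLogitsLoss()
d_optimizer = optim.Adam(D.parameters(), lr=0.0005, betas=(0.3, 0.999))
g_optimizer = optim.Adam(G.parameters(), lr=0.0005, betas=(0.3, 0.999))

losses = []
for epoch in range(n_epochs):
    for real_images, _ in loader:
        n = real_images.size(0)
        real_images = (real_images * 2 - 1).to(device)   # rescale [0, 1] -> [-1, 1]
        ones = torch.ones(n, device=device)
        zeros = torch.zeros(n, device=device)

        # Discriminator step: real images -> 1, generated images -> 0
        d_optimizer.zero_grad()
        z = torch.empty(n, z_size, device=device).uniform_(-1, 1)
        fake_images = G(z)
        d_loss = (criterion(D(real_images).view(-1), ones)
                  + criterion(D(fake_images.detach()).view(-1), zeros))
        d_loss.backward()
        d_optimizer.step()

        # Generator step: try to make the discriminator output 1 for fakes
        g_optimizer.zero_grad()
        z = torch.empty(n, z_size, device=device).uniform_(-1, 1)
        g_loss = criterion(D(G(z)).view(-1), ones)
        g_loss.backward()
        g_optimizer.step()

    # Record the losses of the last batch of each epoch
    losses.append((d_loss.item(), g_loss.item()))
    print(f"epoch {epoch + 1}: d_loss={d_loss.item():.4f} g_loss={g_loss.item():.4f}")
```

The `losses` list collected here is the kind of data a loss graph can be plotted from.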


After training, the graph of losses looks something like this.

Figure: Graph of losses

We can see that the discriminator loss is quite smooth and converges to a specific value even after 100 epochs. The generator loss, on the other hand, spiked.

And we can see from the results below that after 60 epochs the generated images are distorted. So we can conclude that 60 epochs could be considered an optimal training length, as both losses are at a minimum at that point and the generator is producing some pretty good images!


Step 6: Test the Model

Generated samples after 10, 20, 30, 40, 50, 60, 70, 80, 90, and 100 epochs (one figure per checkpoint).

Conclusion:

We can see that training a Generative Adversarial Network for longer doesn't necessarily mean it will generate better images.

The results show that between 40 and 60 epochs the generator produced relatively better images than at other points in training.


You can try changing the optimizers, learning rate, and other hyperparameters to make it generate better images! Kudos to you for making it this far!

This entire notebook can be found here.


Here are my GitHub and LinkedIn accounts. Feel free to connect.




