DCGAN代码解析--上手调试

DCGAN代码解析

今天我们将对GAN领域中经典的论文DCGAN做一个简单的解析。

论文地址:Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks(Radford 等,arXiv:1511.06434;原文链接指向 ReadPaper 论文阅读平台)

1 初始化

import argparse
import itertools
import os

import numpy as np
import torch
# Hyper-parameters for the DCGAN walkthrough. `parse_args(args=[])` ignores the
# real command line (presumably so this also runs inside a notebook), so every
# option below takes its default value.
# NOTE(review): --n_cpu is parsed but not passed to the DataLoader in this file.
parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
# beta2 is Adam's decay rate for the *second* moment (squared-gradient) estimate.
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of second order momentum of gradient")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
parser.add_argument("--img_size", type=int, default=32, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=1, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=400, help="interval between image sampling")
opt = parser.parse_args(args=[])
print(opt)
Namespace(b1=0.5, b2=0.999, batch_size=64, channels=1, img_size=32, latent_dim=100, lr=0.0002, n_cpu=8, n_epochs=200, sample_interval=400)

2 数据加载

加载后的数据为 32×32 的单通道灰度图,像素值经 `Normalize([0.5], [0.5])` 映射到 [-1, 1] 区间(与生成器输出端 Tanh 的取值范围一致)。

import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision import datasets

# Configure data loader: MNIST training split. Each digit is resized to
# opt.img_size (MNIST digits are square, so img_size x img_size), converted to a
# tensor by ToTensor and then mapped with (x - 0.5) / 0.5.
os.makedirs("../../data/mnist", exist_ok=True)
mnist_transform = transforms.Compose([
    transforms.Resize(opt.img_size),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])
mnist_train = datasets.MNIST(
    "../../data/mnist",
    train=True,
    download=True,
    transform=mnist_transform,
)
dataloader = torch.utils.data.DataLoader(mnist_train, batch_size=opt.batch_size, shuffle=True)
from torch.autograd import Variable
import matplotlib.pyplot as plt

def show_img(img, trans=True):
    """Display one image in grayscale with matplotlib.

    Args:
        img: when ``trans`` is True, a channel-first tensor of shape (C, H, W);
            it is detached, moved to the CPU, converted to NumPy and only
            channel 0 is drawn. When ``trans`` is False, anything ``plt.imshow``
            can draw directly (e.g. an (H, W) array).
        trans: whether ``img`` must first be converted from a channel-first tensor.
    """
    if not trans:
        plt.imshow(img, cmap="gray")
    else:
        # Channel-first (C, H, W) -> channel-last (H, W, C), then keep channel 0.
        hwc = img.detach().cpu().numpy().transpose(1, 2, 0)
        plt.imshow(hwc[:, :, 0], cmap="gray")
    plt.show()
    
# Preview the first three raw training digits. No transform is given, so each
# item is an (image, label) pair holding the original 28x28 image. `sample` stays
# bound to the last digit shown and is reused by the transform walkthrough below.
mnist = datasets.MNIST("../../data/mnist")
for i in range(3):
    sample, _label = mnist[i]
    show_img(np.array(sample), trans=False)


![MNIST 样本 1](test_files/test_6_0.png)

![MNIST 样本 2](test_files/test_6_1.png)

![MNIST 样本 3](test_files/test_6_2.png)

# Apply the loader's three transforms one at a time to `sample` (the last digit
# previewed above) and print every intermediate result.
trans_resize = transforms.Resize(opt.img_size)
trans_to_tensor = transforms.ToTensor()
trans_normalize = transforms.Normalize([0.5], [0.5])  # x_n = (x - 0.5) / 0.5

raw_pixels = np.array(sample)
print("shape =", raw_pixels.shape, '\n')
print("data =", raw_pixels, '\n')

# Step 1: resize to opt.img_size.
sample_resize = trans_resize(sample)
print("(trans_resize) shape =", np.array(sample_resize).shape, '\n')

# Step 2: image -> float tensor.
sample_tensor = trans_to_tensor(sample_resize)
print("(trans_to_tensor) data =", sample_tensor, '\n')

# Step 3: shift/scale with mean 0.5, std 0.5.
sample_normalize = trans_normalize(sample_tensor)
print("(trans_normalize) data =", sample_normalize, '\n')
shape = (28, 28) 

data = [[  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0  67 232  39   0   0   0   0   0]
 [  0   0   0   0  62  81   0   0   0   0   0   0   0   0   0   0   0   0
    0   0 120 180  39   0   0   0   0   0]
 [  0   0   0   0 126 163   0   0   0   0   0   0   0   0   0   0   0   0
    0   2 153 210  40   0   0   0   0   0]
 [  0   0   0   0 220 163   0   0   0   0   0   0   0   0   0   0   0   0
    0  27 254 162   0   0   0   0   0   0]
 [  0   0   0   0 222 163   0   0   0   0   0   0   0   0   0   0   0   0
    0 183 254 125   0   0   0   0   0   0]
 [  0   0   0  46 245 163   0   0   0   0   0   0   0   0   0   0   0   0
    0 198 254  56   0   0   0   0   0   0]
 [  0   0   0 120 254 163   0   0   0   0   0   0   0   0   0   0   0   0
   23 231 254  29   0   0   0   0   0   0]
 [  0   0   0 159 254 120   0   0   0   0   0   0   0   0   0   0   0   0
  163 254 216  16   0   0   0   0   0   0]
 [  0   0   0 159 254  67   0   0   0   0   0   0   0   0   0  14  86 178
  248 254  91   0   0   0   0   0   0   0]
 [  0   0   0 159 254  85   0   0   0  47  49 116 144 150 241 243 234 179
  241 252  40   0   0   0   0   0   0   0]
 [  0   0   0 150 253 237 207 207 207 253 254 250 240 198 143  91  28   5
  233 250   0   0   0   0   0   0   0   0]
 [  0   0   0   0 119 177 177 177 177 177  98  56   0   0   0   0   0 102
  254 220   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0 169
  254 137   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0 169
  254  57   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0 169
  254  57   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0 169
  255  94   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0 169
  254  96   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0 169
  254 153   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0 169
  255 153   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0  96
  254 153   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0]] 

(trans_resize) shape = (32, 32) 

(trans_to_tensor) data = tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]]) 

(trans_normalize) data = tensor([[[-1., -1., -1.,  ..., -1., -1., -1.],
         [-1., -1., -1.,  ..., -1., -1., -1.],
         [-1., -1., -1.,  ..., -1., -1., -1.],
         ...,
         [-1., -1., -1.,  ..., -1., -1., -1.],
         [-1., -1., -1.,  ..., -1., -1., -1.],
         [-1., -1., -1.,  ..., -1., -1., -1.]]]) 

3 模型

3.1生成器

包含1个全连接层和3个卷积层,使用LeakyReLU和Tanh激活函数,使用了BatchNorm和Upsample
![BatchNorm](figures/BN.png)
![Upsample(resize)](figures/resize.png)

import torch.nn as nn

class Generator(nn.Module):
    """DCGAN generator: maps a latent noise vector ``z`` to an image.

    ``z`` of shape ``(batch, latent_dim)`` is projected by a linear layer and
    reshaped to a ``(batch, 128, img_size // 4, img_size // 4)`` feature map.
    Two ``Upsample(x2) -> Conv3x3 -> BatchNorm -> LeakyReLU`` stages grow it back
    to ``img_size x img_size``. A final 3x3 conv + ``Tanh`` then produces a
    ``(batch, channels, img_size, img_size)`` image with values in [-1, 1].

    Args:
        latent_dim: length of ``z``; defaults to ``opt.latent_dim``.
        img_size: output height/width; must be a positive multiple of 4, because
            an ``img_size // 4`` map is upsampled exactly twice. Defaults to
            ``opt.img_size``.
        channels: number of output image channels; defaults to ``opt.channels``.

    Raises:
        ValueError: if ``img_size`` is not a positive multiple of 4.
    """

    def __init__(self, latent_dim=None, img_size=None, channels=None):
        super(Generator, self).__init__()
        latent_dim = opt.latent_dim if latent_dim is None else latent_dim
        img_size = opt.img_size if img_size is None else img_size
        channels = opt.channels if channels is None else channels
        if img_size <= 0 or img_size % 4:
            # Otherwise the output would silently be 4 * (img_size // 4) pixels wide.
            raise ValueError("img_size must be a positive multiple of 4, got %r" % (img_size,))

        self.init_size = img_size // 4
        self.l1 = nn.Sequential(nn.Linear(latent_dim, 128 * self.init_size ** 2))

        # BatchNorm layers use the default eps on purpose: BatchNorm2d's second
        # positional argument is `eps`, not `momentum`, so a positional 0.8 would
        # set eps=0.8. Layer order/indices are kept so state_dict keys match.
        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, channels, 3, stride=1, padding=1),
            nn.Tanh(),
        )

    def forward(self, z):
        """Map ``z`` of shape ``(batch, latent_dim)`` to images of shape ``(batch, channels, img_size, img_size)``."""
        out = self.l1(z)
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        img = self.conv_blocks(out)
        return img
    
# Build the generator from the global `opt` hyper-parameters and print its layer summary.
generator = Generator()
print(generator)
Generator(
  (l1): Sequential(
    (0): Linear(in_features=100, out_features=8192, bias=True)
  )
  (conv_blocks): Sequential(
    (0): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (1): Upsample(scale_factor=2, mode=nearest)
    (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): BatchNorm2d(128, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace)
    (5): Upsample(scale_factor=2, mode=nearest)
    (6): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): BatchNorm2d(64, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (8): LeakyReLU(negative_slope=0.2, inplace)
    (9): Conv2d(64, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (10): Tanh()
  )
)

3.2判别器

包含4个卷积层和1个全连接层,使用LeakyReLU和Sigmoid激活函数,使用了Dropout和BatchNorm,使用Strided Conv进行下采样
![Strided Conv](figures/strided.png)

class Discriminator(nn.Module):
    """DCGAN discriminator: maps an image to the probability that it is real.

    Four stride-2 ``Conv3x3 -> LeakyReLU -> Dropout2d [-> BatchNorm]`` blocks
    (16, 32, 64, 128 filters; no BatchNorm in the first) downsample the input.
    A linear layer + ``Sigmoid`` then turns the flattened features into a
    ``(batch, 1)`` score in [0, 1].

    Args:
        img_size: input height/width; defaults to ``opt.img_size``.
        channels: number of input channels; defaults to ``opt.channels``.
    """

    def __init__(self, img_size=None, channels=None):
        super(Discriminator, self).__init__()
        img_size = opt.img_size if img_size is None else img_size
        channels = opt.channels if channels is None else channels

        def discriminator_block(in_filters, out_filters, bn=True):
            """One stride-2 downsampling block; BatchNorm is appended when ``bn``."""
            block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25)]
            if bn:
                # Default eps on purpose: a positional 0.8 here would set eps, not momentum.
                block.append(nn.BatchNorm2d(out_filters))
            return block

        self.model = nn.Sequential(
            *discriminator_block(channels, 16, bn=False),
            *discriminator_block(16, 32),
            *discriminator_block(32, 64),
            *discriminator_block(64, 128),
        )

        # The height and width of downsampled image. A 3x3, stride-2, padding-1 conv
        # maps n -> ceil(n / 2), so img_size // 16 is only right for multiples of 16
        # (e.g. 28 -> 14 -> 7 -> 4 -> 2, whereas 28 // 16 == 1).
        ds_size = img_size
        for _ in range(4):
            ds_size = (ds_size + 1) // 2
        self.adv_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, 1), nn.Sigmoid())

    def forward(self, img):
        """Return validity scores of shape ``(batch, 1)`` for ``img`` of shape ``(batch, channels, img_size, img_size)``."""
        out = self.model(img)
        out = out.view(out.shape[0], -1)
        validity = self.adv_layer(out)

        return validity
    
# Build the discriminator from the global `opt` hyper-parameters and print its layer summary.
discriminator = Discriminator()
print(discriminator)
Discriminator(
  (model): Sequential(
    (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2, inplace)
    (2): Dropout2d(p=0.25)
    (3): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (4): LeakyReLU(negative_slope=0.2, inplace)
    (5): Dropout2d(p=0.25)
    (6): BatchNorm2d(32, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (7): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (8): LeakyReLU(negative_slope=0.2, inplace)
    (9): Dropout2d(p=0.25)
    (10): BatchNorm2d(64, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (11): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (12): LeakyReLU(negative_slope=0.2, inplace)
    (13): Dropout2d(p=0.25)
    (14): BatchNorm2d(128, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
  )
  (adv_layer): Sequential(
    (0): Linear(in_features=512, out_features=1, bias=True)
    (1): Sigmoid()
  )
)

3.3初始化

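对卷积层和 BatchNorm 层进行参数初始化:卷积核权重从均值 0、标准差 0.02 的正态分布中采样;BatchNorm 的缩放参数 γ 从均值 1、标准差 0.02 的正态分布中采样,偏移参数 β 置 0;全连接层保持 PyTorch 默认初始化。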

def weights_init_normal(m):
    """DCGAN-style weight initialisation, meant to be passed to ``nn.Module.apply``.

    - Modules whose class name contains "Conv" (e.g. ``Conv2d``): weights are
      drawn from N(mean=0, std=0.02); the bias is left untouched.
    - Modules whose class name contains "BatchNorm2d": scale drawn from
      N(mean=1, std=0.02), shift set to 0.
    - Everything else (``Linear``, activations, containers, ...) is left as is.

    Modules without a learnable ``weight`` (e.g. ``BatchNorm2d(affine=False)``,
    or a container whose name happens to contain "Conv") are skipped instead of
    raising ``AttributeError``.
    """
    classname = m.__class__.__name__
    weight = getattr(m, "weight", None)
    if weight is None:
        return  # nothing to initialise
    if "Conv" in classname:
        torch.nn.init.normal_(weight.data, 0.0, 0.02)
    elif "BatchNorm2d" in classname:
        torch.nn.init.normal_(weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)

# Initialize weights of both networks (generator first, then discriminator).
for net in (generator, discriminator):
    net.apply(weights_init_normal)
Discriminator(
  (model): Sequential(
    (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2, inplace)
    (2): Dropout2d(p=0.25)
    (3): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (4): LeakyReLU(negative_slope=0.2, inplace)
    (5): Dropout2d(p=0.25)
    (6): BatchNorm2d(32, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (7): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (8): LeakyReLU(negative_slope=0.2, inplace)
    (9): Dropout2d(p=0.25)
    (10): BatchNorm2d(64, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (11): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (12): LeakyReLU(negative_slope=0.2, inplace)
    (13): Dropout2d(p=0.25)
    (14): BatchNorm2d(128, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
  )
  (adv_layer): Sequential(
    (0): Linear(in_features=512, out_features=1, bias=True)
    (1): Sigmoid()
  )
)

4 损失函数

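使用二元交叉熵(Binary Cross Entropy, BCE)损失:

$$\ell(\hat{y}, y) = -\left[\, y \log \hat{y} + (1 - y) \log (1 - \hat{y}) \,\right]$$

其中 $\hat{y}$ 为判别器经 Sigmoid 输出的概率,$y$ 为标签(真实图像为 1,生成图像为 0);`nn.BCELoss` 默认对 batch 取平均。

![BCE Loss](figures/BCE-loss.png)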

# Loss function: binary cross-entropy between the discriminator's sigmoid output
# and the 1 (real) / 0 (fake) targets, averaged over the batch by default.
adversarial_loss = nn.BCELoss()

5 Cuda加速

# Move the networks and loss to the GPU when one is available, and pick the
# matching legacy tensor constructor; `Tensor(...)` is used below to build the
# network inputs and BCE targets on the right device.
cuda = torch.cuda.is_available()
print("cuda_is_available =", cuda)
if cuda:
    generator.cuda()
    discriminator.cuda()
    adversarial_loss.cuda()
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
cuda_is_available = True

6 优化器

使用Adam优化器

# Optimizers: one Adam instance per network, sharing lr and (beta1, beta2) from opt.
adam_betas = (opt.b1, opt.b2)
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=adam_betas)
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=adam_betas)
print("learning_rate =", opt.lr)
learning_rate = 0.0002

7 创建输入

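分别获取两路输入:从数据集中取一个 batch 的真实图像 `real_imgs`,并从标准正态分布 $\mathcal{N}(0, 1)$ 中采样噪声向量 `z`(形状为 `batch_size × latent_dim`)作为生成器的输入。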

# Take only the first batch to inspect the two inputs of the GAN. `islice`
# stops after one batch; materialising `list(enumerate(dataloader))` would load
# and keep the whole epoch in memory just to use its first element.
for i, (imgs, labels) in itertools.islice(enumerate(dataloader), 1):
    # Configure input: cast the real batch to the selected (CPU or CUDA) float tensor type
    real_imgs = Variable(imgs.type(Tensor))
    # Sample noise as generator input: one N(0, 1) vector of length latent_dim per image
    z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))
    print("i =", i, '\n')
    print("shape of z =", z.shape, '\n')
    print("shape of real_imgs =", real_imgs.shape, '\n')
    print("z =", z, '\n')
    print("real_imgs =")
    for img in real_imgs[:3]:
        show_img(img)
i = 0 

shape of z = torch.Size([64, 100]) 

shape of real_imgs = torch.Size([64, 1, 32, 32]) 

z = tensor([[ 3.1224e-01, -1.1344e-01, -1.0401e+00,  ...,  1.8232e-01,
         -1.2940e+00,  1.3365e+00],
        [ 7.3029e-01,  4.0669e-01, -1.3267e-01,  ..., -4.9197e-01,
         -7.5093e-01, -1.1240e+00],
        [ 1.2938e+00,  7.8608e-01,  1.8455e-01,  ..., -5.0269e-01,
          7.9739e-01, -5.3891e-02],
        ...,
        [-7.9207e-01, -4.8256e-02,  4.5883e-01,  ...,  1.2142e+00,
          6.2461e-01, -1.5289e+00],
        [-1.4916e-03,  4.8395e-01, -3.0754e-01,  ..., -1.8773e-01,
         -5.0988e-01, -1.2065e+00],
        [ 1.2712e+00, -5.0849e-01,  6.2769e-01,  ...,  1.0904e+00,
          2.1514e-01, -4.0929e-01]], device='cuda:0') 

real_imgs =

![真实图像样本 1](test_files/test_21_1.png)

![真实图像样本 2](test_files/test_21_2.png)

![真实图像样本 3](test_files/test_21_3.png)

8 计算loss,反向传播

分别对生成器和判别器计算 loss,并通过反向传播更新模型参数:先以"让判别器把生成图像判为真(标签 1)"为目标更新生成器,再用真实图像(标签 1)与生成图像(标签 0)两项 BCE 的平均值更新判别器。

    # One adversarial training step on the current batch (`imgs` / `real_imgs` from
    # the input section): the generator is updated first, then the discriminator.

    # Adversarial ground truths
    batch_size = imgs.shape[0]
    # Tensor(batch_size, 1) allocates a (batch_size, 1) float tensor; fill_ sets every entry to the target.
    valid = Variable(Tensor(batch_size, 1).fill_(1.0), requires_grad=False) # target 1 = judged real
    fake = Variable(Tensor(batch_size, 1).fill_(0.0), requires_grad=False) # target 0 = judged fake
    
    # ---------------------
    #  Train Generator
    # ---------------------
    
    optimizer_G.zero_grad()

    # Sample noise as generator input (a fresh z, independent of the one drawn in the input section)
    z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))

    # Generate a batch of images
    gen_imgs = generator(z)

    # Loss measures generator's ability to fool the discriminator: the fake batch is
    # scored against the "real" targets, so minimising it pushes D(G(z)) towards 1.
    g_loss = adversarial_loss(discriminator(gen_imgs), valid)

    # backward() also accumulates gradients in the discriminator's parameters; only
    # optimizer_G steps here, and optimizer_D.zero_grad() below clears those gradients.
    g_loss.backward()
    optimizer_G.step()

    # ---------------------
    #  Train Discriminator
    # ---------------------

    optimizer_D.zero_grad()

    # Measure discriminator's ability to classify real from generated samples.
    # gen_imgs is detached so this loss does not backpropagate into the generator.
    real_loss = adversarial_loss(discriminator(real_imgs), valid)
    fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
    # Average of the two BCE terms (real labelled 1, generated labelled 0)
    d_loss = (real_loss + fake_loss) / 2
    print("real_loss =", real_loss, '\n')
    print("fake_loss =", fake_loss, '\n')
    print("d_loss =", d_loss, '\n')    
    
    d_loss.backward()
    optimizer_D.step()
real_loss = tensor(0.7088, device='cuda:0', grad_fn=<BinaryCrossEntropyBackward>) 

fake_loss = tensor(0.6778, device='cuda:0', grad_fn=<BinaryCrossEntropyBackward>) 

d_loss = tensor(0.6933, device='cuda:0', grad_fn=<DivBackward0>) 

9 保存生成图像和模型文件

    from torchvision.utils import save_image

    def sample_image(n_row, batches_done):
        """Save an ``n_row x n_row`` grid of generated images.

        Draws ``n_row ** 2`` latent vectors from N(0, 1) and feeds them to the
        global ``generator``. The batch is written to ``images/%d.png`` (formatted
        with ``batches_done``) via ``save_image`` with ``nrow=n_row`` and
        ``normalize=True``. The ``images`` directory is not created here.

        Args:
            n_row: number of images per row, and number of rows, in the grid.
            batches_done: global batch counter, used as the output file name.
        """
        # Sample noise (this GAN is unconditional, so there are no class labels).
        z = Variable(Tensor(np.random.normal(0, 1, (n_row ** 2, opt.latent_dim))))
        # Sampling only: skip building an autograd graph. The generator's train/eval
        # mode is not changed, so the produced images are the same as before.
        with torch.no_grad():
            gen_imgs = generator(z)
        save_image(gen_imgs.data, "images/%d.png" % batches_done, nrow=n_row, normalize=True)
    
    # Global batch counter; `epoch` is a hard-coded placeholder (marked temporary),
    # so this is just the batch index `i`.
    epoch = 0  # temporary
    batches_done = i + epoch * len(dataloader)
    if not batches_done % opt.sample_interval:
        # Every opt.sample_interval batches: write a 10x10 grid of samples ...
        os.makedirs("images", exist_ok=True)
        sample_image(n_row=10, batches_done=batches_done)

        # ... and checkpoint both networks as pickled whole modules.
        os.makedirs("model", exist_ok=True)
        for net, ckpt_path in ((generator, 'model/generator.pkl'), (discriminator, 'model/discriminator.pkl')):
            torch.save(net, ckpt_path)

        print("gen images saved!\n", "model saved!", sep="\n")
gen images saved!

model saved!

rue)

# Global batch counter; `epoch` is a hard-coded placeholder (marked temporary),
# so this is just the batch index `i`.
epoch = 0  # temporary
batches_done = i + epoch * len(dataloader)
if not batches_done % opt.sample_interval:
    # Every opt.sample_interval batches: write a 10x10 grid of samples ...
    os.makedirs("images", exist_ok=True)
    sample_image(n_row=10, batches_done=batches_done)

    # ... and checkpoint both networks as pickled whole modules.
    os.makedirs("model", exist_ok=True)
    for net, ckpt_path in ((generator, 'model/generator.pkl'), (discriminator, 'model/discriminator.pkl')):
        torch.save(net, ckpt_path)

    print("gen images saved!\n", "model saved!", sep="\n")

    gen images saved!
    
    model saved!

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Yuetianw

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值