DCGAN代码解析
今天我们将对GAN领域的经典论文DCGAN的代码实现做一个简单的解析。
1 初始化
import argparse
import os
import numpy as np
import torch
# Hyper-parameters of the DCGAN walkthrough, declared as command-line options.
parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
# b2 is the second Adam beta; the help text previously repeated the b1 wording.
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of second order momentum of gradient")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
parser.add_argument("--img_size", type=int, default=32, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=1, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=400, help="interval between image sampling")
# args=[] makes argparse ignore sys.argv, so the defaults above are always used.
opt = parser.parse_args(args=[])
print(opt)
Namespace(b1=0.5, b2=0.999, batch_size=64, channels=1, img_size=32, latent_dim=100, lr=0.0002, n_cpu=8, n_epochs=200, sample_interval=400)
2 数据加载
加载后的数据为 32 * 32 的灰度图
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision import datasets
# Configure data loader
# The MNIST training split is downloaded into mnist_root on first use. Each
# image is resized to opt.img_size, converted to a tensor, and normalized
# with mean 0.5 / std 0.5.
mnist_root = "../../data/mnist"
os.makedirs(mnist_root, exist_ok=True)

preprocess = transforms.Compose(
    [
        transforms.Resize(opt.img_size),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
)
train_set = datasets.MNIST(mnist_root, train=True, download=True, transform=preprocess)

# Shuffled mini-batches of opt.batch_size images.
dataloader = DataLoader(train_set, batch_size=opt.batch_size, shuffle=True)
from torch.autograd import Variable
import matplotlib.pyplot as plt
def show_img(img, trans=True):
    """Display an image in grayscale with matplotlib.

    Args:
        img: Image to display. When ``trans`` is true it is detached, moved
            to the CPU, converted to NumPy and its axes are reordered with
            (1, 2, 0); only index 0 of the last axis is drawn. Otherwise it
            is passed to ``plt.imshow`` unchanged.
        trans: Whether to apply the conversion described above.
    """
    if not trans:
        plt.imshow(img, cmap="gray")
    else:
        # Move the channel dimension to the end.
        pixels = img.detach().cpu().numpy().transpose(1, 2, 0)
        plt.imshow(pixels[:, :, 0], cmap="gray")
    plt.show()
# Open the dataset without any transform and display its first three images.
# `sample` stays bound to the last image shown; the transform walkthrough
# further down reuses it.
mnist = datasets.MNIST("../../data/mnist")
for i in range(3):
    sample, _label = mnist[i]
    show_img(np.array(sample), trans=False)
![test_6_0](test_files/test_6_0.png)
![test_6_1](test_files/test_6_1.png)
![test_6_2](test_files/test_6_2.png)
# Apply the three preprocessing steps one at a time to `sample`, printing the
# intermediate result after each step.
trans_resize = transforms.Resize(opt.img_size)
trans_to_tensor = transforms.ToTensor()
trans_normalize = transforms.Normalize([0.5], [0.5])  # x_n = (x - 0.5) / 0.5

# Raw image as stored in the dataset.
raw_pixels = np.array(sample)
print("shape =", raw_pixels.shape, '\n')
print("data =", raw_pixels, '\n')

# Step 1: resize.
sample_resize = trans_resize(sample)
print("(trans_resize) shape =", np.array(sample_resize).shape, '\n')

# Step 2: convert to a tensor.
sample_tensor = trans_to_tensor(sample_resize)
print("(trans_to_tensor) data =", sample_tensor, '\n')

# Step 3: normalize with mean 0.5 and std 0.5.
sample_normalize = trans_normalize(sample_tensor)
print("(trans_normalize) data =", sample_normalize, '\n')
shape = (28, 28)
data = [[ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0]
[ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0]
[ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0]
[ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0]
[ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0]
[ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 67 232 39 0 0 0 0 0]
[ 0 0 0 0 62 81 0 0 0 0 0 0 0 0 0 0 0 0
0 0 120 180 39 0 0 0 0 0]
[ 0 0 0 0 126 163 0 0 0 0 0 0 0 0 0 0 0 0
0 2 153 210 40 0 0 0 0 0]
[ 0 0 0 0 220 163 0 0 0 0 0 0 0 0 0 0 0 0
0 27 254 162 0 0 0 0 0 0]
[ 0 0 0 0 222 163 0 0 0 0 0 0 0 0 0 0 0 0
0 183 254 125 0 0 0 0 0 0]
[ 0 0 0 46 245 163 0 0 0 0 0 0 0 0 0 0 0 0
0 198 254 56 0 0 0 0 0 0]
[ 0 0 0 120 254 163 0 0 0 0 0 0 0 0 0 0 0 0
23 231 254 29 0 0 0 0 0 0]
[ 0 0 0 159 254 120 0 0 0 0 0 0 0 0 0 0 0 0
163 254 216 16 0 0 0 0 0 0]
[ 0 0 0 159 254 67 0 0 0 0 0 0 0 0 0 14 86 178
248 254 91 0 0 0 0 0 0 0]
[ 0 0 0 159 254 85 0 0 0 47 49 116 144 150 241 243 234 179
241 252 40 0 0 0 0 0 0 0]
[ 0 0 0 150 253 237 207 207 207 253 254 250 240 198 143 91 28 5
233 250 0 0 0 0 0 0 0 0]
[ 0 0 0 0 119 177 177 177 177 177 98 56 0 0 0 0 0 102
254 220 0 0 0 0 0 0 0 0]
[ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 169
254 137 0 0 0 0 0 0 0 0]
[ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 169
254 57 0 0 0 0 0 0 0 0]
[ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 169
254 57 0 0 0 0 0 0 0 0]
[ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 169
255 94 0 0 0 0 0 0 0 0]
[ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 169
254 96 0 0 0 0 0 0 0 0]
[ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 169
254 153 0 0 0 0 0 0 0 0]
[ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 169
255 153 0 0 0 0 0 0 0 0]
[ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 96
254 153 0 0 0 0 0 0 0 0]
[ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0]
[ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0]
[ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0]]
(trans_resize) shape = (32, 32)
(trans_to_tensor) data = tensor([[[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]]])
(trans_normalize) data = tensor([[[-1., -1., -1., ..., -1., -1., -1.],
[-1., -1., -1., ..., -1., -1., -1.],
[-1., -1., -1., ..., -1., -1., -1.],
...,
[-1., -1., -1., ..., -1., -1., -1.],
[-1., -1., -1., ..., -1., -1., -1.],
[-1., -1., -1., ..., -1., -1., -1.]]])
3 模型
3.1生成器
包含1个全连接层和3个卷积层,使用LeakyReLU和Tanh激活函数,使用了BatchNorm和Upsample
![BN](figures/BN.png)
![resize](figures/resize.png)
import torch.nn as nn
class Generator(nn.Module):
    """Map latent vectors to images.

    A linear layer expands the latent vector into a 128-channel feature map
    of side ``opt.img_size // 4``. Two (Upsample x2 -> Conv -> BatchNorm ->
    LeakyReLU) stages follow, then a final convolution to ``opt.channels``
    channels with a Tanh output.
    """

    def __init__(self):
        super().__init__()

        # Two 2x upsampling stages follow, so start at a quarter of the
        # target side length.
        self.init_size = opt.img_size // 4
        self.l1 = nn.Sequential(nn.Linear(opt.latent_dim, 128 * self.init_size ** 2))

        layers = [nn.BatchNorm2d(128)]
        for in_channels, out_channels in ((128, 128), (128, 64)):
            layers += [
                nn.Upsample(scale_factor=2),
                nn.Conv2d(in_channels, out_channels, 3, stride=1, padding=1),
                # NOTE(review): the second positional argument of BatchNorm2d
                # is eps, so 0.8 sets eps=0.8 (the printed model shows
                # eps=0.8, momentum=0.1). momentum may have been intended;
                # kept as-is to preserve behavior.
                nn.BatchNorm2d(out_channels, 0.8),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        layers += [
            nn.Conv2d(64, opt.channels, 3, stride=1, padding=1),
            nn.Tanh(),
        ]
        self.conv_blocks = nn.Sequential(*layers)

    def forward(self, z):
        """Generate a batch of images from the latent batch ``z``."""
        features = self.l1(z)
        # Reshape the flat linear output into a 128-channel feature map.
        features = features.view(features.size(0), 128, self.init_size, self.init_size)
        return self.conv_blocks(features)
# Build the generator and print its layer structure.
generator = Generator()
print(generator)
Generator(
(l1): Sequential(
(0): Linear(in_features=100, out_features=8192, bias=True)
)
(conv_blocks): Sequential(
(0): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(1): Upsample(scale_factor=2, mode=nearest)
(2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(3): BatchNorm2d(128, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
(4): LeakyReLU(negative_slope=0.2, inplace)
(5): Upsample(scale_factor=2, mode=nearest)
(6): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(7): BatchNorm2d(64, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
(8): LeakyReLU(negative_slope=0.2, inplace)
(9): Conv2d(64, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(10): Tanh()
)
)
3.2判别器
包含4个卷积层和1个全连接层,使用LeakyReLU和Sigmoid激活函数,使用了Dropout和BatchNorm,使用Strided Conv进行下采样
![strided](figures/strided.png)
class Discriminator(nn.Module):
    """Score images with a value in (0, 1).

    Four stride-2 convolution stages (Conv -> LeakyReLU -> Dropout2d, plus
    BatchNorm on all but the first stage) are followed by a linear layer
    with a Sigmoid output.
    """

    def __init__(self):
        super().__init__()

        def discriminator_block(in_filters, out_filters, bn=True):
            """Return the layers of one stride-2 downsampling stage."""
            layers = [
                nn.Conv2d(in_filters, out_filters, 3, 2, 1),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Dropout2d(0.25),
            ]
            if bn:
                # NOTE(review): the second positional argument of BatchNorm2d
                # is eps, so 0.8 sets eps=0.8 (the printed model shows
                # eps=0.8, momentum=0.1). momentum may have been intended;
                # kept as-is to preserve behavior.
                layers.append(nn.BatchNorm2d(out_filters, 0.8))
            return layers

        widths = (opt.channels, 16, 32, 64, 128)
        stages = []
        for index, (in_filters, out_filters) in enumerate(zip(widths, widths[1:])):
            # Only the first stage is built without BatchNorm.
            stages.extend(discriminator_block(in_filters, out_filters, bn=index > 0))
        self.model = nn.Sequential(*stages)

        # The height and width of downsampled image
        # (four stride-2 convolutions divide each side by 2 ** 4).
        ds_size = opt.img_size // 2 ** 4
        self.adv_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, 1), nn.Sigmoid())

    def forward(self, img):
        """Return one validity score per image in ``img``."""
        features = self.model(img)
        flat = features.view(features.shape[0], -1)
        return self.adv_layer(flat)
# Build the discriminator and print its layer structure.
discriminator = Discriminator()
print(discriminator)
Discriminator(
(model): Sequential(
(0): Conv2d(1, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(1): LeakyReLU(negative_slope=0.2, inplace)
(2): Dropout2d(p=0.25)
(3): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(4): LeakyReLU(negative_slope=0.2, inplace)
(5): Dropout2d(p=0.25)
(6): BatchNorm2d(32, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
(7): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(8): LeakyReLU(negative_slope=0.2, inplace)
(9): Dropout2d(p=0.25)
(10): BatchNorm2d(64, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
(11): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(12): LeakyReLU(negative_slope=0.2, inplace)
(13): Dropout2d(p=0.25)
(14): BatchNorm2d(128, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
)
(adv_layer): Sequential(
(0): Linear(in_features=512, out_features=1, bias=True)
(1): Sigmoid()
)
)
3.3初始化
对卷积层和BatchNorm层进行参数初始化
def weights_init_normal(m):
    """Re-initialize the parameters of convolution and BatchNorm2d layers.

    Modules whose class name contains "Conv" get weights drawn from
    N(0.0, 0.02). Modules whose class name contains "BatchNorm2d" get
    weights drawn from N(1.0, 0.02) and a zero bias. Every other module is
    left untouched.

    Args:
        m: The module to initialize in place.
    """
    layer_type = type(m).__name__
    if "Conv" in layer_type:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
        return
    if "BatchNorm2d" in layer_type:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)
# Initialize weights
# Module.apply runs weights_init_normal on each submodule and on the module
# itself, then returns the module. The Discriminator listing shown after this
# snippet is presumably the echoed value of the last expression.
generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)
Discriminator(
(model): Sequential(
(0): Conv2d(1, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(1): LeakyReLU(negative_slope=0.2, inplace)
(2): Dropout2d(p=0.25)
(3): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(4): LeakyReLU(negative_slope=0.2, inplace)
(5): Dropout2d(p=0.25)
(6): BatchNorm2d(32, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
(7): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(8): LeakyReLU(negative_slope=0.2, inplace)
(9): Dropout2d(p=0.25)
(10): BatchNorm2d(64, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
(11): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(12): LeakyReLU(negative_slope=0.2, inplace)
(13): Dropout2d(p=0.25)
(14): BatchNorm2d(128, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
)
(adv_layer): Sequential(
(0): Linear(in_features=512, out_features=1, bias=True)
(1): Sigmoid()
)
)
4 损失函数
使用二元交叉熵(Binary Cross Entropy, BCE)损失函数
![BCE-loss](figures/BCE-loss.png)
# Loss function
# Binary cross-entropy between the discriminator's Sigmoid output and the
# 1.0 / 0.0 target tensors built in the training step.
adversarial_loss = torch.nn.BCELoss()
5 Cuda加速
# Use the GPU when one is available.
cuda = torch.cuda.is_available()
print("cuda_is_available =", cuda)

if cuda:
    # Move both networks and the loss module to the GPU.
    for component in (generator, discriminator, adversarial_loss):
        component.cuda()

# Tensor constructor matching the selected device; used below for the noise
# and the target labels.
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
cuda_is_available = True
6 优化器
使用Adam优化器
# Optimizers
# Each network gets its own Adam optimizer; both share the same learning rate
# and betas.
adam_settings = {"lr": opt.lr, "betas": (opt.b1, opt.b2)}
optimizer_G = torch.optim.Adam(generator.parameters(), **adam_settings)
optimizer_D = torch.optim.Adam(discriminator.parameters(), **adam_settings)
print("learning_rate =", opt.lr)
learning_rate = 0.0002
7 创建输入
分别从数据集和随机向量中获取输入
# Fetch only the first batch. The previous form,
# list(enumerate(dataloader))[:1], loaded every batch of the dataset into
# memory just to keep the first one.
i, (imgs, labels) = next(enumerate(dataloader))

# Configure input
real_imgs = Variable(imgs.type(Tensor))

# Sample noise as generator input: one latent vector per real image.
z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))

print("i =", i, '\n')
print("shape of z =", z.shape, '\n')
print("shape of real_imgs =", real_imgs.shape, '\n')
print("z =", z, '\n')
print("real_imgs =")
# Show the first three real images of the batch.
for img in real_imgs[:3]:
    show_img(img)
i = 0
shape of z = torch.Size([64, 100])
shape of real_imgs = torch.Size([64, 1, 32, 32])
z = tensor([[ 3.1224e-01, -1.1344e-01, -1.0401e+00, ..., 1.8232e-01,
-1.2940e+00, 1.3365e+00],
[ 7.3029e-01, 4.0669e-01, -1.3267e-01, ..., -4.9197e-01,
-7.5093e-01, -1.1240e+00],
[ 1.2938e+00, 7.8608e-01, 1.8455e-01, ..., -5.0269e-01,
7.9739e-01, -5.3891e-02],
...,
[-7.9207e-01, -4.8256e-02, 4.5883e-01, ..., 1.2142e+00,
6.2461e-01, -1.5289e+00],
[-1.4916e-03, 4.8395e-01, -3.0754e-01, ..., -1.8773e-01,
-5.0988e-01, -1.2065e+00],
[ 1.2712e+00, -5.0849e-01, 6.2769e-01, ..., 1.0904e+00,
2.1514e-01, -4.0929e-01]], device='cuda:0')
real_imgs =
![test_21_1](test_files/test_21_1.png)
![test_21_2](test_files/test_21_2.png)
![test_21_3](test_files/test_21_3.png)
8 计算loss,反向传播
分别对生成器和判别器计算loss,使用反向传播更新模型参数
# One training step on the batch fetched above: the generator is updated
# first, then the discriminator.

# Adversarial ground truths
batch_size = imgs.shape[0]
valid = Variable(Tensor(batch_size, 1).fill_(1.0), requires_grad=False) # target 1.0: judged real
fake = Variable(Tensor(batch_size, 1).fill_(0.0), requires_grad=False) # target 0.0: judged fake
# ---------------------
# Train Generator
# ---------------------
optimizer_G.zero_grad()
# Sample noise as generator input
z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))
# Generate a batch of images
gen_imgs = generator(z)
# Loss measures generator's ability to fool the discriminator
# (generated images are scored against the "valid" target).
g_loss = adversarial_loss(discriminator(gen_imgs), valid)
# This backward pass also leaves gradients on the discriminator's parameters;
# optimizer_D.zero_grad() below clears them before the discriminator update.
g_loss.backward()
optimizer_G.step()
# ---------------------
# Train Discriminator
# ---------------------
optimizer_D.zero_grad()
# Measure discriminator's ability to classify real from generated samples
real_loss = adversarial_loss(discriminator(real_imgs), valid)
# detach() cuts the graph so this loss does not backpropagate into the generator.
fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
# Discriminator loss is the mean of the real and fake losses.
d_loss = (real_loss + fake_loss) / 2
print("real_loss =", real_loss, '\n')
print("fake_loss =", fake_loss, '\n')
print("d_loss =", d_loss, '\n')
d_loss.backward()
optimizer_D.step()
real_loss = tensor(0.7088, device='cuda:0', grad_fn=<BinaryCrossEntropyBackward>)
fake_loss = tensor(0.6778, device='cuda:0', grad_fn=<BinaryCrossEntropyBackward>)
d_loss = tensor(0.6933, device='cuda:0', grad_fn=<DivBackward0>)
9 保存生成图像和模型文件
from torchvision.utils import save_image
def sample_image(n_row, batches_done):
    """Save an ``n_row`` x ``n_row`` grid of generated images.

    The grid is written to ``images/<batches_done>.png``; the ``images``
    directory is created if it does not exist yet.

    Args:
        n_row: Number of images per grid row; ``n_row ** 2`` images are
            generated in total.
        batches_done: Number of batches processed so far, used as the
            output file name.
    """
    # Ensure the output directory exists so the function also works when
    # called on its own.
    os.makedirs("images", exist_ok=True)
    # Sample noise
    z = Variable(Tensor(np.random.normal(0, 1, (n_row ** 2, opt.latent_dim))))
    # Sampling only needs a forward pass, so skip building the autograd graph.
    # NOTE(review): the generator is not switched to eval mode in the visible
    # code, so BatchNorm presumably still uses batch statistics here; kept
    # unchanged.
    with torch.no_grad():
        gen_imgs = generator(z)
    save_image(gen_imgs.data, "images/%d.png" % batches_done, nrow=n_row, normalize=True)
epoch = 0  # temporary: stand-in for the epoch counter of a full training loop
# Global batch counter: completed epochs times batches per epoch, plus the
# index of the current batch.
batches_done = epoch * len(dataloader) + i

if not batches_done % opt.sample_interval:
    # Save a grid of generated images.
    os.makedirs("images", exist_ok=True)
    sample_image(n_row=10, batches_done=batches_done)

    # Save the models.
    # NOTE(review): torch.save is given the whole module objects here rather
    # than their state_dict; kept as-is so the saved files stay the same.
    os.makedirs("model", exist_ok=True)
    checkpoints = (
        (generator, 'model/generator.pkl'),
        (discriminator, 'model/discriminator.pkl'),
    )
    for network, target_path in checkpoints:
        torch.save(network, target_path)

    print("gen images saved!\n")
    print("model saved!")
gen images saved!
model saved!