昇思25天学习打卡营第21天 | GAN图像生成
GAN模型
生成对抗网络(Generative Adversarial Networks, GAN)是一种生成式机器学习模型,主要由两个不同的模型共同构成:
- 生成器(Generator):负责生成看起来像训练图像的“假”图像;
- 判别器(Discriminator):判断生成器输出的图像是真实训练图像还是虚假的图像。
GAN通过两个网络模块的互相博弈完成训练。
用 x x x代表图像数据, D ( x ) D(x) D(x)表示判别器给出图像为真实图像的概率。当 x x x来自于训练数据时, D ( x ) D(x) D(x)趋近于 1 1 1;当 x x x来自生成器时, D ( x ) D(x) D(x)趋近于 0 0 0。
用 z z z代表标准正态分布中提取出的隐码(隐向量), G ( z ) G(z) G(z)表示将隐码 z z z映射到数据空间的生成器函数。函数 G ( z ) G(z) G(z)的目的时将服从高斯分布的随机噪声 z z z通过生成网络变换为近似于真实分布 p d a t a ( x ) p_{data}(x) pdata(x)的数据分布。目的是找到 θ \theta θ使得 P G ( x ; θ ) P_G(x;\theta) PG(x;θ)和 p d a t a ( x ) p_{data}(x) pdata(x)尽可能的接近,其中 θ \theta θ代表网络参数。
用 D ( G ( z ) ) D(G(z)) D(G(z))代表生成器 G G G生成的假图像被判定为真实图像的概率。
在模型训练中,判别器 D D D想要最大程度的正确分类真实图像与假图像,即参数 log D ( x ) \log D(x) logD(x);而生成器 G G G试图七篇 D D D来最小化假图像被识别到的概率,即参数 log ( 1 − D ( G ( z ) ) ) \log(1-D(G(z))) log(1−D(G(z)))。
因此GAN的损失函数为:
min
G
max
D
V
(
D
,
G
)
=
E
x
∼
p
d
a
t
a
(
x
)
[
log
D
(
x
)
]
+
E
z
∼
p
z
(
z
)
[
log
(
1
−
D
(
G
(
z
)
)
)
]
\min_G\max_D V(D,G)=E_{x\sim p_{data}}(x)[\log D(x)]+E_{z\sim p_z(z)}[\log(1-D(G(z)))]
GminDmaxV(D,G)=Ex∼pdata(x)[logD(x)]+Ez∼pz(z)[log(1−D(G(z)))]
数据集
实验使用MNIST手写数字数据集,共有70000张手写数字图片,包含60000张训练样本和10000张测试样本,数字图片为二进制文件,大小为 28 × 28 28\times28 28×28,单通道。
数据下载
# 数据下载
from download import download
url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/MNIST_Data.zip"
download(url, ".", kind="zip", replace=True)
数据加载
import numpy as np
import mindspore.dataset as ds
batch_size = 64
latent_size = 100 # 隐码的长度
train_dataset = ds.MnistDataset(dataset_dir='./MNIST_Data/train')
test_dataset = ds.MnistDataset(dataset_dir='./MNIST_Data/test')
def data_load(dataset):
dataset1 = ds.GeneratorDataset(dataset, ["image", "label"], shuffle=True, python_multiprocessing=False,num_samples=10000)
# 数据增强
mnist_ds = dataset1.map(
operations=lambda x: (x.astype("float32"), np.random.normal(size=latent_size).astype("float32")),
output_columns=["image", "latent_code"])
mnist_ds = mnist_ds.project(["image", "latent_code"])
# 批量操作
mnist_ds = mnist_ds.batch(batch_size, True)
return mnist_ds
mnist_ds = data_load(train_dataset)
隐码构造
为了跟踪生成器的学习进度,在训练的每轮迭代结束后,将一组固定的遵循高斯分布的隐码test_noise
输入到生成器中,通过固定隐码生成的图像效果来评估生成器的好坏。
import random
import numpy as np
from mindspore import Tensor
from mindspore.common import dtype
# 利用随机种子创建一批隐码
np.random.seed(2323)
test_noise = Tensor(np.random.normal(size=(25, 100)), dtype.float32)
random.shuffle(test_noise)
模型构建
生成器
from mindspore import nn
import mindspore.ops as ops
img_size = 28 # 训练图像长(宽)
class Generator(nn.Cell):
def __init__(self, latent_size, auto_prefix=True):
super(Generator, self).__init__(auto_prefix=auto_prefix)
self.model = nn.SequentialCell()
# [N, 100] -> [N, 128]
# 输入一个100维的0~1之间的高斯分布,然后通过第一层线性变换将其映射到256维
self.model.append(nn.Dense(latent_size, 128))
self.model.append(nn.ReLU())
# [N, 128] -> [N, 256]
self.model.append(nn.Dense(128, 256))
self.model.append(nn.BatchNorm1d(256))
self.model.append(nn.ReLU())
# [N, 256] -> [N, 512]
self.model.append(nn.Dense(256, 512))
self.model.append(nn.BatchNorm1d(512))
self.model.append(nn.ReLU())
# [N, 512] -> [N, 1024]
self.model.append(nn.Dense(512, 1024))
self.model.append(nn.BatchNorm1d(1024))
self.model.append(nn.ReLU())
# [N, 1024] -> [N, 784]
# 经过线性变换将其变成784维
self.model.append(nn.Dense(1024, img_size * img_size))
# 经过Tanh激活函数是希望生成的假的图片数据分布能够在-1~1之间
self.model.append(nn.Tanh())
def construct(self, x):
img = self.model(x)
return ops.reshape(img, (-1, 1, 28, 28))
net_g = Generator(latent_size)
net_g.update_parameters_name('generator')
判别器
# 判别器
class Discriminator(nn.Cell):
def __init__(self, auto_prefix=True):
super().__init__(auto_prefix=auto_prefix)
self.model = nn.SequentialCell()
# [N, 784] -> [N, 512]
self.model.append(nn.Dense(img_size * img_size, 512)) # 输入特征数为784,输出为512
self.model.append(nn.LeakyReLU()) # 默认斜率为0.2的非线性映射激活函数
# [N, 512] -> [N, 256]
self.model.append(nn.Dense(512, 256)) # 进行一个线性映射
self.model.append(nn.LeakyReLU())
# [N, 256] -> [N, 1]
self.model.append(nn.Dense(256, 1))
self.model.append(nn.Sigmoid()) # 二分类激活函数,将实数映射到[0,1]
def construct(self, x):
x_flat = ops.reshape(x, (-1, img_size * img_size))
return self.model(x_flat)
net_d = Discriminator()
net_d.update_parameters_name('discriminator')
损失函数和优化器
lr = 0.0002 # 学习率
# 损失函数
adversarial_loss = nn.BCELoss(reduction='mean')
# 优化器
optimizer_d = nn.Adam(net_d.trainable_params(), learning_rate=lr, beta1=0.5, beta2=0.999)
optimizer_g = nn.Adam(net_g.trainable_params(), learning_rate=lr, beta1=0.5, beta2=0.999)
optimizer_g.update_parameters_name('optim_g')
optimizer_d.update_parameters_name('optim_d')
模型训练
模型训练包含两个部分:
- 训练判别器。通过提高随机梯度来更新判别器,最大化 log D ( x ) + log ( 1 − D ( G ( z ) ) \log D(x)+\log(1-D(G(z)) logD(x)+log(1−D(G(z))。
- 训练生成器。最小化 log ( 1 − D ( G ( z ) ) \log(1-D(G(z)) log(1−D(G(z))来训练生成器,以产生更好的虚假图像。
import os
import time
import matplotlib.pyplot as plt
import mindspore as ms
from mindspore import Tensor, save_checkpoint
total_epoch = 12 # 训练周期数
batch_size = 64 # 用于训练的训练集批量大小
# 加载预训练模型的参数
pred_trained = False
pred_trained_g = './result/checkpoints/Generator99.ckpt'
pred_trained_d = './result/checkpoints/Discriminator99.ckpt'
checkpoints_path = "./result/checkpoints" # 结果保存路径
image_path = "./result/images" # 测试结果保存路径
# 生成器计算损失过程
def generator_forward(test_noises):
fake_data = net_g(test_noises)
fake_out = net_d(fake_data)
loss_g = adversarial_loss(fake_out, ops.ones_like(fake_out))
return loss_g
# 判别器计算损失过程
def discriminator_forward(real_data, test_noises):
fake_data = net_g(test_noises)
fake_out = net_d(fake_data)
real_out = net_d(real_data)
real_loss = adversarial_loss(real_out, ops.ones_like(real_out))
fake_loss = adversarial_loss(fake_out, ops.zeros_like(fake_out))
loss_d = real_loss + fake_loss
return loss_d
# 梯度方法
grad_g = ms.value_and_grad(generator_forward, None, net_g.trainable_params())
grad_d = ms.value_and_grad(discriminator_forward, None, net_d.trainable_params())
def train_step(real_data, latent_code):
# 计算判别器损失和梯度
loss_d, grads_d = grad_d(real_data, latent_code)
optimizer_d(grads_d)
loss_g, grads_g = grad_g(latent_code)
optimizer_g(grads_g)
return loss_d, loss_g
# 保存生成的test图像
def save_imgs(gen_imgs1, idx):
for i3 in range(gen_imgs1.shape[0]):
plt.subplot(5, 5, i3 + 1)
plt.imshow(gen_imgs1[i3, 0, :, :] / 2 + 0.5, cmap="gray")
plt.axis("off")
plt.savefig(image_path + "/test_{}.png".format(idx))
# 设置参数保存路径
os.makedirs(checkpoints_path, exist_ok=True)
# 设置中间过程生成图片保存路径
os.makedirs(image_path, exist_ok=True)
net_g.set_train()
net_d.set_train()
# 储存生成器和判别器loss
losses_g, losses_d = [], []
for epoch in range(total_epoch):
start = time.time()
for (iter, data) in enumerate(mnist_ds):
start1 = time.time()
image, latent_code = data
image = (image - 127.5) / 127.5 # [0, 255] -> [-1, 1]
image = image.reshape(image.shape[0], 1, image.shape[1], image.shape[2])
d_loss, g_loss = train_step(image, latent_code)
end1 = time.time()
if iter % 10 == 10:
print(f"Epoch:[{int(epoch):>3d}/{int(total_epoch):>3d}], "
f"step:[{int(iter):>4d}/{int(iter_size):>4d}], "
f"loss_d:{d_loss.asnumpy():>4f} , "
f"loss_g:{g_loss.asnumpy():>4f} , "
f"time:{(end1 - start1):>3f}s, "
f"lr:{lr:>6f}")
end = time.time()
print("time of epoch {} is {:.2f}s".format(epoch + 1, end - start))
losses_d.append(d_loss.asnumpy())
losses_g.append(g_loss.asnumpy())
# 每个epoch结束后,使用生成器生成一组图片
gen_imgs = net_g(test_noise)
save_imgs(gen_imgs.asnumpy(), epoch)
# 根据epoch保存模型权重文件
if epoch % 1 == 0:
save_checkpoint(net_g, checkpoints_path + "/Generator%d.ckpt" % (epoch))
save_checkpoint(net_d, checkpoints_path + "/Discriminator%d.ckpt" % (epoch))
总结
这一小节对GAN网络进行了介绍,通过生成器 G G G和判别器 D D D两个网络模型之间的博弈来训练网络,使得判别器难以区分生成的图像是否为虚假图像,最终导致生成器生成的图像接近于真实图像分布。此外,通过将一组固定的隐码输入生成器中,就可以跟踪生成器的学习进度,从而判断输出图像的好坏。