一、网络结构
网络由生成器和判别器两部分组成,网络结构图如下所示:
生成器功能: 生成接近真实样本的假样本
生成器的输入: 随机噪声
生成器的输出: 与真实样本相同大小的假样本
判别器的功能: 判断输入是真样本还是假样本
判别器的输入: 真实样本 和 生成器生成的假样本
判别器的输出: 结果是真实样本的概率值, p∈[0, 1]
二、损失函数
损失函数第一部分: E x~pdata(x)[logD(x)]
从数据集随机抽样出来的真实样本,判别器将其判断为1
损失函数第二部分: Ez~pnoize(z)[log(1-D(G(z)))]
对于生成器而言,希望判别器将G(z)判别为真实数据,即D(G(z))的结果接近1,总体损失越小越好。
对于鉴别器而言,希望正确的判断生成器生成的假数据,即D(G(z))的结果接近0,总体损失越大越好。
由此可见,生成器和鉴别器的优化目标是相反的,即最小最大优化:
三、生成对抗逻辑
一方面,生成器会不断提高自己造假的能力,不断进化,以骗过鉴别器。
另一方面,鉴别器也会不断提高自己的鉴别能力,不断升级,以分别真假样本。
因此,通过这种对抗训练的方式,生成器和鉴别器的能力都被不断提高,从而使得生成器能够生成较为真实的假样本。
其过程大致如下:
-
v1版本的Generator只能生成较为模糊的图片,v1版本的Discriminator很容易鉴别出它是假样本。
-
然后,Generator升级为v2版本,能生成有颜色有眼睛的图片。此时,已经能够骗过v1版本的Discriminator。
-
接着,Discriminator也升级为v2版本,学习到了如何鉴别v2版Generator生成的假样本。
-
然后,Generator再继续升级… 接着,Discriminator也跟着升级…
四、案例实现
利用MINIST数据集,训练一个能生成手写体数字的生成器。
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from torch.utils.tensorboard import SummaryWriter
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# 读取数据 并对数据进行归一化[-1, 1]
transform = transforms.Compose([
transforms.ToTensor(), # 将图片转换为Tensor,归一化至[0,1];将图片改为channel*height*width
transforms.Normalize(0.5, 0.5) # 标准化至[-1, 1];规定均值和标准差
])
# 训练集
trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
dataloader = torch.utils.data.DataLoader(trainset, batch_size=16, shuffle=True)
# 定义生成器
# 生成器的输入:长度为100的随机噪声(正太分布随机数)
# 生成器的输出:生成的图片(1*28*28)
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
self.main = nn.Sequential(
nn.Linear(100, 256), # 输入为100维的随机噪声,输出为256维的特征
nn.ReLU(), # 激活函数
nn.Linear(256, 512),
nn.ReLU(), # 激活函数v
nn.Linear(512, 28 * 28), # 输出为28*28的图片
nn.Tanh() # 激活函数
)
def forward(self, input):
output = self.main(input)
output = output.view(-1, 28, 28) # 将输出的图片reshape为1*28*28
return output
# 定义判别器
# 判别器的输入:真实图片(1*28*28) / 生成的图片(1*28*28)
# 判别器的输出:判断图片为真的概率(0~1)
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
self.main = nn.Sequential(
nn.Linear(28 * 28, 512), # 输入为28*28的图片,输出为512维的特征
nn.LeakyReLU(), # 激活函数
nn.Linear(512, 256),
nn.LeakyReLU(), # 激活函数
nn.Linear(256, 1),
nn.Sigmoid() # 激活函数
)
def forward(self, input):
input = input.view(-1, 28 * 28) # 将输入的图片reshape为28*28
output = self.main(input)
return output
# 定义损失函数
loss_fn = torch.nn.BCELoss() # 二分类交叉熵损失函数
# 定义优化器
# 生成器的优化器
gen = Generator().to(device)
g_optim = optim.Adam(gen.parameters(), lr=0.0001)
# 判别器的优化器
dis = Discriminator().to(device)
d_optim = optim.Adam(dis.parameters(), lr=0.0001)
# 绘图函数
def plot_img(model, test_input):
prediction = np.squeeze(model(test_input).detach().cpu().numpy())
fig = plt.figure(figsize=(4, 4))
for i in range(16):
plt.subplot(4, 4, i + 1)
plt.imshow((prediction[i] + 1) / 2)
plt.axis('off')
plt.show()
test_input = torch.randn(16, 100, device=device)
D_loss = []
G_loss = []
writer = SummaryWriter(log_dir='runs/logs')
# 训练
for epoch in range(50):
d_epoch_loss = 0
g_epoch_loss = 0
count = len(dataloader)
for step, (img, _) in enumerate(dataloader):
img = img.to(device)
size = img.size(0)
random_noise = torch.randn(size, 100, device=device)
# 训练判别器
d_optim.zero_grad()
# 1.1 真实图片
real_output = dis(img) # 判别器判断真实图片的概率
d_real_loss = loss_fn(real_output, torch.ones_like(real_output)) # 真实图片的损失
d_real_loss.backward() # 反向传播
# 1.2 生成图片
fake_img = gen(random_noise) # 生成图片
fake_output = dis(fake_img.detach()) # 判别器判断生成图片的概率
d_fake_loss = loss_fn(fake_output, torch.zeros_like(fake_output)) # 生成图片的损失
d_fake_loss.backward() # 反向传播
# 1.3 损失函数
d_loss = d_real_loss + d_fake_loss # 判别器的损失
d_optim.step() # 更新参数
# 训练生成器
g_optim.zero_grad()
fake_output = dis(fake_img) # 判别器判断生成图片的概率
g_loss = loss_fn(fake_output, torch.ones_like(fake_output)) # 生成图片的损失
g_loss.backward() # 反向传播
g_optim.step() # 更新参数
with torch.no_grad():
d_epoch_loss += d_loss.item()
g_epoch_loss += g_loss.item()
with torch.no_grad():
d_epoch_loss /= count
g_epoch_loss /= count
D_loss.append(d_epoch_loss)
G_loss.append(g_epoch_loss)
print('Epoch: {}, Step: {}, D_loss: {:.4f}, G_loss: {:.4f}'.format(epoch, step, d_epoch_loss, g_epoch_loss))
writer.add_scalar('D_loss', d_epoch_loss, epoch)
writer.add_scalar('G_loss', g_epoch_loss, epoch)
plot_img(gen, test_input)
五、实现结果
网络训练过程中,随着epoch的增多,生成器生成的结果越来越真实。具体图像如下所示:
Epoch1 | Epoch10 | Epoch30 | Epoch50 |
---|---|---|---|
D_loss | G_loss |
---|---|
如上图所示,由两个损失曲线可以看出:
对于生成器而言,它的损失一开始较大,在训练过程中呈减小的趋势;
原因分析:生成器是从随机噪声开始学习,因此最初的时候损失会比较大。
对于鉴别器而言,它的损失一开始较小,在训练过程中呈增大的趋势;
原因分析:在最开始的时候,鉴别器很容易判断学得不好的生成器生成的假样本,但生成器在对抗训练中不断改进提升,所以鉴别器鉴别真假样本越来越难。