本文是 GAN 在 MNIST 数据集上生成假的手写数字图片的一个实例,具体是用 pytorch 实现的。
先来看下训练的结果,下面两张图中,上面的是真实手写数字图片,下面的是在训练了30个 epoch 之后 GAN 生成的假图片。整体来说,效果还是蛮不错的。
GAN 由一个生成器和一个判别器组成。在该任务中,生成器的输入是一堆随机生成的噪音,其输出为生成的假图片,其目标是让生成的假图片尽可能地像真实图片;而判别器的输入是一张图片,其输出是这张图片是真实图片的概率。
在训练时,需要先训练判别器再训练生成器,因为判别器的好坏决定着生成器的效果。在生成器训练时,先将一堆噪音输入到生成器并得到假图片,然后再将假图片输入到判别器进行判别,然后将判别结果与真实标签(注意不是假标签,因为生成器的目标是尽可能地模拟真实图片)进行比对形成损失函数。判别器的训练分为两个部分,一是对真实图片进行判别,二是对假图片进行判别,得到的判别结果分别与真实标签和假标签对比形成真实图片的损失和假图片的损失,两者相加就是判别器的总损失。
这个代码有个蛋疼的地方是,如果用 keras 引入 MNIST 数据集,则在训练时损失函数会很快就趋近于 0 了,训练效果很差,而用 torchvision 时则没这个问题。
import os
import cv2
import torch
import torchvision
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets
from keras.datasets import mnist
from torchvision.utils import save_image
# Mini-batch size: used for the training DataLoader and for the number of
# noise vectors drawn in G_train / D_train.
batch_size = 100
# Number of passes over the training set.
epoch_num = 30
# Learning rate handed to both Adam optimizers.
lr = 0.0002
# Length of each random noise vector fed to the generator.
input_dim = 100
class Generator(nn.Module):
    """Map noise vectors of shape (N, input_dim) to images of shape (N, 1, 28, 28).

    A linear layer expands each noise vector to 56 * 56 values, which are
    reshaped into a single-channel 56x56 map, refined by two 3x3 convolution
    stages and finally down-sampled to 28x28 by a stride-2 convolution.
    The closing Tanh keeps every output value inside [-1, 1].
    """

    def __init__(self, input_dim):
        super().__init__()
        # Noise -> flat 56x56 feature map.
        self.fc1 = nn.Linear(in_features=input_dim, out_features=56 * 56)
        # Normalisation + activation applied right after the reshape.
        self.br = nn.Sequential(
            nn.BatchNorm2d(num_features=1),
            nn.ReLU(inplace=True),
        )
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=50, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(num_features=50),
            nn.ReLU(inplace=True),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=50, out_channels=25, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(num_features=25),
            nn.ReLU(inplace=True),
        )
        # Kernel 2 / stride 2 halves the spatial size: 56x56 -> 28x28.
        self.conv3 = nn.Sequential(
            nn.Conv2d(in_channels=25, out_channels=1, kernel_size=2, stride=2),
            nn.Tanh(),
        )

    def forward(self, x):
        """Return the generated image batch for the noise batch ``x``."""
        feature_map = self.fc1(x).view(-1, 1, 56, 56)
        for stage in (self.br, self.conv1, self.conv2, self.conv3):
            feature_map = stage(feature_map)
        return feature_map
class Discriminator(nn.Module):
    """Score single-channel 28x28 images with a value in [0, 1].

    Two 5x5 convolution stages, each followed by 2x2 average pooling, reduce
    a (N, 1, 28, 28) batch to (N, 64, 7, 7). The flattened features pass
    through two fully connected stages; the final Sigmoid yields an output
    of shape (N, 1).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=5, stride=1, padding=2),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
        )
        self.pl1 = nn.AvgPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, stride=1, padding=2),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
        )
        self.pl2 = nn.AvgPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Sequential(
            nn.Linear(in_features=64 * 7 * 7, out_features=1024),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
        )
        self.fc2 = nn.Sequential(
            nn.Linear(in_features=1024, out_features=1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return the per-image score for the image batch ``x``."""
        features = self.pl1(self.conv1(x))
        features = self.pl2(self.conv2(features))
        # Collapse channels and spatial dims into one vector per image.
        flat = features.view(features.size(0), -1)
        return self.fc2(self.fc1(flat))
def G_train(input_dim):
    """Run one optimisation step of the generator and return its loss.

    Relies on the module-level ``G``, ``D``, ``criterion``, ``G_optimizer``,
    ``device`` and ``batch_size``.

    Args:
        input_dim: Length of each noise vector fed to the generator.

    Returns:
        The generator loss for this step as a Python float.
    """
    G_optimizer.zero_grad()
    noise = torch.randn(batch_size, input_dim, device=device)
    D_output = D(G(noise))
    # The generator is trained against the *real* label: it wants D to
    # classify its samples as real. The target is built with ones_like so
    # its shape always matches D's output; BCELoss requires input and
    # target to have the same size.
    real_label = torch.ones_like(D_output)
    G_loss = criterion(D_output, real_label)
    # backward() also fills gradients of D's parameters, but only
    # G_optimizer steps here, so D is left untouched.
    G_loss.backward()
    G_optimizer.step()
    return G_loss.item()
def D_train(real_img, input_dim):
    """Run one optimisation step of the discriminator and return its loss.

    The loss is the sum of the loss on ``real_img`` (target 1) and the loss
    on a freshly generated fake batch (target 0). Relies on the module-level
    ``G``, ``D``, ``criterion``, ``D_optimizer``, ``device`` and
    ``batch_size``.

    Args:
        real_img: Batch of real images, already on ``device``.
        input_dim: Length of each noise vector fed to the generator.

    Returns:
        The total discriminator loss for this step as a Python float.
    """
    D_optimizer.zero_grad()

    # Real images should be scored as 1. Targets are built with *_like so
    # their shape always matches D's output, as BCELoss requires.
    real_output = D(real_img)
    D_real_loss = criterion(real_output, torch.ones_like(real_output))

    # Generated images should be scored as 0. The fake batch is detached:
    # this step only updates D, so there is no point in back-propagating
    # through (and accumulating gradients in) the generator.
    noise = torch.randn(batch_size, input_dim, device=device)
    fake_img = G(noise).detach()
    fake_output = D(fake_img)
    D_fake_loss = criterion(fake_output, torch.zeros_like(fake_output))

    D_loss = D_real_loss + D_fake_loss
    D_loss.backward()
    D_optimizer.step()
    return D_loss.item()
def save_img(img, img_name):
    """Write an image batch to ``./imgs/<img_name>``.

    Values are shifted and scaled from [-1, 1] to [0, 1] and clamped to
    that range before being handed to ``save_image``.
    """
    rescaled = ((img + 1) * 0.5).clamp(0, 1)
    target_path = "./imgs/" + img_name
    save_image(rescaled, target_path)
if __name__ == "__main__":
    # NOTE: everything below intentionally stays at module level. G_train and
    # D_train read G, D, criterion, the optimizers and device as globals, so
    # wrapping this in a function would hide those names from them.
    os.makedirs("./checkpoint", exist_ok=True)
    os.makedirs("./imgs", exist_ok=True)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load the data. ToTensor yields pixels in [0, 1]; Normalize maps them to
    # [-1, 1] so real images share the range of the generator's Tanh output
    # (and of what save_img expects).
    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5,), (0.5,)),
    ])
    train_dataset = datasets.MNIST(root='./mnist_data/', train=True, transform=transform,
                                   download=True)
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
    num_batches = len(train_loader)

    # Build the generator and discriminator, resuming from checkpoints when
    # both are present.
    if os.path.exists('./checkpoint/Generator.pkl') and os.path.exists('./checkpoint/Discriminator.pkl'):
        # SECURITY: these files hold whole pickled modules; torch.load will
        # unpickle arbitrary objects, so only load checkpoints you created.
        # NOTE(review): recent PyTorch releases default torch.load to
        # weights_only=True, which rejects whole pickled modules - confirm
        # against the installed version (saving state_dict()s avoids this).
        # map_location lets a checkpoint saved on GPU be resumed on CPU.
        G = torch.load("./checkpoint/Generator.pkl", map_location=device).to(device)
        D = torch.load("./checkpoint/Discriminator.pkl", map_location=device).to(device)
    else:
        G = Generator(input_dim).to(device)
        D = Discriminator().to(device)

    # Loss function and optimizers.
    criterion = nn.BCELoss()
    G_optimizer = optim.Adam(G.parameters(), lr=lr)
    D_optimizer = optim.Adam(D.parameters(), lr=lr)

    print("Training...........")
    for epoch in range(1, epoch_num + 1):
        print("epoch: ", epoch)
        for batch, (x, _) in enumerate(train_loader):
            # Train the discriminator first, then the generator; the order
            # must not be swapped.
            D_loss = D_train(x.to(device), input_dim)
            G_loss = G_train(input_dim)
            print("[ %d / %d ] g_loss: %.6f d_loss: %.6f" % (batch, num_batches, float(G_loss), float(D_loss)))
            if batch % 50 == 0:
                # Preview samples only: no autograd graph is needed, and the
                # noise must live on the same device as G.
                with torch.no_grad():
                    noise = torch.randn(128, input_dim, device=device)
                    fake_img = G(noise)
                save_img(fake_img, "img_" + str(epoch) + "_" + str(batch) + ".png")
        # Checkpoint after every epoch so an interrupted run can be resumed.
        torch.save(G, "./checkpoint/Generator.pkl")
        torch.save(D, "./checkpoint/Discriminator.pkl")