简介
生成对抗网络(Generative adversarial networks)是深度学习领域的一个重要生成模型,当然还有其他的生成模型,比如VAE和其他GAN变种模型 。
为什么叫做生成对抗网络。是因为GAN的主要结构包括一个生成器G(Generator)和一个判别器D(Discriminator)。
下面以生产图片为例进行分析
生成网络
接收一个随机的噪音数据(一般服从正态分布),生成图片,记作G(Z)。Z表示噪声数据。这些数据我们可以随机生成,一般符合高斯(正态分布)
判别网络
- 判断真实图像的输出结果
输入为真实数据X,输出X为真实图片的概率(0-1),记作D(X)。为1,就代表100%是真实的图片,而输出为0,就代表不可能是真实的图片。 - 判断生成图像的输出结果
输入为生成器生成的图片G(Z),输出为1或0(真或假),记作D(G(Z)
优化目标:
1.让生成器生成的图片G(Z)尽可能为真,骗过判别器。
2. 让判别器D提高精确度,将真实图片判为真,将生成的图片判为假
生成网络:
l
o
g
l
o
g
(
D
(
G
(
z
)
)
loglog(D(G(z))
loglog(D(G(z))越接近1越好,即生成的图片被判别为真实的
判别网络:
l
o
g
(
D
(
x
)
)
+
l
o
g
(
1
−
D
(
G
(
z
)
)
)
log(D(x))+log(1-D(G(z)))
log(D(x))+log(1−D(G(z)))
loglog(D(G(z))越接近0越好,即判别器越强大能识别出生成的图片为假
上面这两个目标看似是矛盾的,这也解释了为什么叫做生成对抗网络。这样通过不断的进行多轮训练、“对抗|, 使得最后我们生成的图片可以“以假乱真”,通过判别网络判别为真。最终理想情况下, G 生成的数据与真实数据非常接近,分布也相同,而 D 无论输出真实数据还是 G 生成的数据都输出0.5。
下面以minist数据集进行训练
import torch.autograd
import torch.nn as nn
from torch.autograd import Variable
from torchvision import transforms
from torchvision import datasets
from torchvision.utils import save_image
import os
if not os.path.exists('./img'):
os.mkdir('./img')
def to_img(x):
out = 0.5 * (x + 1)
out = out.clamp(0, 1) # Clamp函数可以将随机变化的数值限制在一个给定的区间[min, max]内:
out = out.view(-1, 1, 28, 28)
return out
batch_size = 128
epochs = 100 #跑100轮
z_dimension = 100 #噪声数据的维度 输入一个100维的0~1之间的高斯分布
# 图形预处理
transform = transforms.Compose([
transforms.ToTensor(), # 转换PIL.Image or numpy.ndarray
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
# 使用 mnist数据集
mnist = datasets.MNIST(root='./data/', train=True, transform=transform, download=True)
# 数据载入
dataloader = torch.utils.data.DataLoader(dataset=mnist, batch_size=batch_size, shuffle=True)
#定义判别器
#这里只是用的简单的几层网络 也可以用卷积神经网络convd进行判别
class discriminator(nn.Module):
def __init__(self):
super(discriminator ,self).__init__()
self.dis = nn.Sequential(
nn.Linear(784 ,256) , # minist数据集为28*28的灰度图,所以输入特征数为28*28*1=784,输出为256
nn.LeakyReLU(0.2) , # 进行非线性映射
nn.Linear(256 ,256) , # 进行一个线性映射
nn.LeakyReLU(0.2),
nn.Linear(256 ,1),
nn.Sigmoid( ) #激活函数,二分类问题中,将实数映射到[0,1],作为概率值,
# 多分类用softmax函数
)
def forward(self, x):
x=self.dis (x)
return x
##定义生成器 Generator
class generator(nn.Module):
def __init__(self):
super(generator,self).__init__()
self.gen=nn.Sequential(
nn.Linear(100,256),# 将一个100维的0~1之间的高斯分布的噪声数据映射到256维,
nn.ReLU(True), # relu激活
nn.Linear(256,256),# 线性变换
nn.ReLU(True), # relu激活
nn.Linear(256,784),# 线性变换
nn.Tanh()# Tanh激活 使得生成数据分布在【-1,1】之间
)
def forward(self, x):
x = self.gen(x)
return x
# 实例化
D = discriminator()
G = generator()
#如果由英伟达的GPU 就用GPU处理,生成图像更快
if torch.cuda.is_available():
D = D.cuda()
G = G.cuda()
##判别器训练
#两部分:1、真实图像
# 2、生成的图像
# 生成网络的参数不更新,始终用同一套参数
# 定义损失函数loss(二分类的交叉熵)
loss = nn.BCELoss() # 是单目标二分类交叉熵函数
optimizer_d = torch.optim.Adam(D.parameters(), lr=0.0003) #学习率选择0.0003
optimizer_g = torch.optim.Adam(G.parameters(), lr=0.0003)
#训练判别器
for epoch in range(epochs): # 进行100个epoch的训练
for i, (img, _) in enumerate(dataloader):
num_img = img.size(0) #pytorch里特征的形式是[bs,channel,h,w],所以img.size(0)就是batchsize(每一个bath的数量)
img = img.view(num_img, -1) # 将图片展开为28*28=784
real_img = Variable(img).cuda() # 将tensor包进Variable
real_label = Variable(torch.ones(num_img)).cuda() # 定义真实的图片label为1
fake_label = Variable(torch.zeros(num_img)).cuda() # 定义假的图片的label为0
#判别真实图片
real_out = D(real_img)
d_loss_real = loss(real_out, real_label) # 得到判别真实图片的loss
real_scores = real_out # 得到真实图片的判别值,输出的值越接近1越好,说明越判断正确了
# 判别生成的假的图片
z = Variable(torch.randn(num_img, z_dimension)).cuda() # 先随机生成一些高斯噪声数组
fake_img = G(z) # 随机噪声放入生成网络中,生成一张假的图片
fake_out = D(fake_img) # 判别器判断假的图片
d_loss_fake = loss(fake_out, fake_label) # 得到假的图片的loss
fake_scores = fake_out # 对于判别器来说,假图片的损失越接近0越好
# 损失函数和优化
d_loss = d_loss_real + d_loss_fake # 损失包括两部分 log(D(x))+log(1-D(G(z)))
optimizer_d.zero_grad() # 在反向传播之前,先将梯度归0
d_loss.backward() # 将误差反向传播
optimizer_d.step() # 更新参数
#训练生成网络
# 将假的图片传入判别器的结果与真实的label对应, 反向传播更新的参数是生成网络里面的参数,这样更新生成网络里面的参数来训练网络,使得生成的图片让判别器判为真的
# 计算假的图片的损失
z = Variable(torch.randn(num_img, z_dimension)).cuda() # 得到随机噪声
fake_img = G(z) # 随机噪声丢入生成器中,生成假的图片
output = D(fake_img)
g_loss = loss(output, real_label) # 假的图片与真实的图片的label的loss
# 反向传播和参数更新
optimizer_g.zero_grad() # 梯度归0
g_loss.backward()
optimizer_g.step()
# 打印中间的损失
if (i + 1) % 100 == 0:
print('Epoch[{}/{}],d_loss:{:.6f},g_loss:{:.6f} '
'D real: {:.6f},D fake: {:.6f}'.format(
epoch, epochs, d_loss.data[0], g_loss.data[0],
real_scores.data.mean(), fake_scores.data.mean() # 打印的是真实图片的损失均值
))
if epoch == 0:
real_images = to_img(real_img.cpu().data)
save_image(real_images, './img/real_images.png')
fake_images = to_img(real_img.cpu().data)
save_image(fake_images, './img/fake_images_{}.png'.format(epoch + 1))
# 保存模型
torch.save(G.state_dict(), './generator.pth')
torch.save(D.state_dict(), './discriminator.pth')