1.GAN技术的介绍
之前一直以为GAN技术只在CV用到,最近在看无监督机器翻译以及翻译质量研究上,有很多方法基于GAN技术,因此不得不弥补以前的“罪过”。GAN技术也是被业界内称作为无监督学习中最具有前景的方法之一。下面将介绍GAN原理以及基于Pytorch实现GAN生成MNIST数据集。
GAN技术的原理可以抽象为:随机给一“青铜原料”,"古玩制作商"根据这块“青铜原料”将其按照神龙鼎的形状打造成近似的假神龙鼎,而“鉴宝专家”每天应对各种各样真假文物,去鉴定文物的真假。其实这是一场博弈,“古玩制作商”为了骗过专家的法眼,会不断去学习,最终达到以假乱真的境界。以下先给出GAN的架构图:
其中随机数z(一般是是一个低纬向量)经过生成器(Generator)生成一个与数据集样本相同的维度。这一步可以就看"古玩制作商"根据“青铜原料”将其按照神龙鼎的形状打造成近似的假神龙鼎。其中随机数z就是青铜原料,生成器(Generator)是“古玩制作商”。最终在真假样本混合的空间里,我们希望识别器D(discrimination)可以识别出哪些样本是真的,哪些样本是假的。识别器就可以看做是“鉴宝专家”。这样解释过于通俗,以下用数学公式解释:
给定随机数,其中属于任意分布。我们通过经过生成器以后得到。假设生成的分布满足,而原本真实的样本空间分布为:。我们希望生成的样本分布近似满足真实样本空间,即:。下一步就是带入识别器中去识别。
GAN技术的目的就是希望通过生成器生成之前不存在但是又很真实的样本。
这里的优化函数定义也比较意思(是一个二分类的交叉熵):
step1:我们固定G,调整D使得值最大:此时,。这一步其实就是提高判别意识,因为只有专家判别能力越高,生成器造假也要被迫提高。
step2:我们固定D,调整G使得值最小:此时。这一步其实在提高生成器的造假能力,使其骗过专家判别。
这两步在代码实现中是分开进行的,而非一起的,最终就是这种内卷,专家识别能力与生成器造假能力进行对抗平衡(这也是在step1的时候趋向于0,而step2趋向于1,对抗!)
最上式的公式求偏导计算,最终得到最优的结果是:
即,博弈论中的纳什均衡,其中生。
2.GAN技术的实现
GAN技术中的数据生成器和识别器其实是两个学习器,具体学习器模型不定。这里我们为了方便计算,我们均采用简单的前馈神经网络。
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from torch.optim import Adam as Adam
transform = transforms.Compose([
#将PILImage或者numpy的ndarray转化成Tensor,这样才能进行下一步归一化
transforms.ToTensor(),
#transforms.Normalize(mean,std)参数:
transforms.Normalize([0.5], [0.5]),
])
batch=128
trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch, shuffle=True)
测试:
for img_real_batch,_ in trainloader:#标签不需要,因为我们不是做图片识别的
img_real_batch=img_real_batch.squeeze().reshape(batch,-1)
print(img_real_batch.shape)
break
结果:
torch.Size([128, 784])
class Generate(nn.Module):
def __init__(self,noise_dimension=50,img_input_size=784):
super(Generate,self).__init__()
self.net=nn.Sequential(
nn.Linear(noise_dimension,256),
nn.ReLU(),
nn.Linear(256,256),
nn.ReLU(),
nn.Linear(256,img_input_size),
nn.Tanh()
)
def forward(self,f):
#f[batch noise_dimension]
x=self.net(f)
return x
class Discriminator(nn.Module):
def __init__(self,INPUT_SIZE=784):
super(Discriminator,self).__init__()
self.input_size=INPUT_SIZE
self.net=nn.Sequential(
nn.Linear(INPUT_SIZE,256),
nn.ReLU(),
nn.Linear(256,256),
nn.ReLU(),
nn.Linear(256,1),
nn.Sigmoid()
)
def forward(self,x):
#x[batch 784]
output=self.net(x)
return output
测试:
noise_dimension=50
input_size=784
device=torch.device("cuda" if torch.cuda.is_available()else "cpu")
G=Generate().to(device)
D=Discriminator().to(device)
f=torch.randn(batch,noise_dimension).to(device)
x=G(f)
print(x.shape)
输出:
torch.Size([128, 784])
output=D(x)
print(output.shape)
输出:
torch.Size([128, 1])
下一步是训练过程,其实GAN技术的核心:
epochs=100
G_optim=Adam(G.parameters(),lr=0.0001)
D_optim=Adam(D.parameters(),lr=0.0001)
criteriom=nn.BCELoss()
for epoch in range(epochs):
G_loss=0
D_loss=0
for real_img,_ in trainloader:
#real_img[batch 1 28 28]
batch=real_img.shape[0]
real_img=real_img.squeeze().reshape(batch,-1).to(device)
#real_img[batch 784]
#随机产生batch个噪声
f=torch.randn(batch,noise_dimension).to(device)
#f[batch noise_dimension]
#造假图像
faker_img=G(f)
#faker_img[batch 784]
#判别器的训练
D_optim.zero_grad()
#真正的图像
output=D(real_img)
label=torch.ones(batch).to(device)
loss_real=criteriom(output,label)
#假的图像
output=D(faker_img.detach())#这里detach是不让其修改G的参数,即固定G。
label=torch.zeros(batch).to(device)
loss_faker=criteriom(output,label)
#交叉熵
loss=loss_real+loss_faker
loss.backward()
#更新梯度提高专家的鉴宝能力
D_optim.step()
D_loss+=loss.item()
#生成器的训练
f=torch.randn(batch,noise_dimension).to(device)
#造假图像
faker_img=G(f)
G_optim.zero_grad()
output=D(faker_img)#此时训练生成器,因此需要反向传播其梯度更新
label=torch.ones(batch).to(device)#这里将生成的标签为真,即欺诈判别器,提高其模仿能力
loss=criteriom(output,label)
loss.backward()
G_optim.step()
G_loss+=loss.item()
print(f"epochs:{epoch+1:03},Generate_loss:{G_loss/len(trainloader):.8f},Discriminator_loss:{D_loss/len(trainloader):.8f}")
if((epoch+1)%10==0):
#画图展示
test=torch.randn(128,50).to(device)
test=G(test).reshape(128,28,28).cpu()
test=(1+test)*0.5
plt.figure()
plt.subplots_adjust(wspace=0,hspace=0)
for i in range(test.shape[0]):
plt.subplot(8,16,i+1)
plt.axis("off")
plt.imshow(test[i].detach(),cmap="gray")
plt.show()
训练结果:
迭代10次:
迭代30次:
迭代50次:
迭代70次:
迭代100次: