GAN的由来和故事理解
GAN是如今非常火的生成式算法的核心,由Lan Goodfellow,Yoshua Bengio等人在2014年提出。大神Lan Goodfellow在2022年的一件事很有意思,因为不满苹果公司不给居家办公而离职跳槽谷歌。只要人越牛,只有你自己才能约束你自己。比起大神Lan Goodfellow这个人,他所提出的GAN的思想更有意思。
形象的说明GAN网络的训练过程就是:
就是画师(G网络)和鉴画师(D网络)的故事。画师从小立志画出能够以假乱真的世界上的各大著名画作,试图骗过鉴画师并借此成为世界上最为富有的人。鉴画师从小立志成为世界上最牛的鉴画师,能够鉴别出投机取巧做假画的人。而画师和鉴画师这两人从小就杠上了。他们俩在各自在同一套数据集中修炼,画师学着画画。而鉴画师,学着将真画判别为真。每过一段时间,画师和鉴画师都要battle一下。画师将他画的给鉴画师看,鉴画师试图学习判别出真画和假画的区别。随着时间推移画师水平越来越高,鉴画师的水准也越来越高,最终他们俩都成为自己领域的大师。
这里附上GAN网络的原文:
https://arxiv.org/abs/1406.2661
附上GAN网络别人开源代码:https://github.com/eriklindernoren/PyTorch-GAN.git
GAN的数学建模
GAN的故事很有意思,也很简单。为更好地分析问题,需要数学将这个故事建模表示出来。
这个就是故事建模成的数学模型。接下来,我们理解下这个模型
L(D,G)表示整个模型训练的损失是关于鉴画师网络D和画师网络G 的函数
Pr(x)表示分布是真实图像x的分布。
而z是输入到G函数的分布,z是一个高斯分布(通常均值为0,方差为1,维度为[batch,dimz]),dimz可以自定义。画师网络G需要输入这样一个噪声分布,G才能依据这个分布生成对应的画。
第一行为G网络的损失,更新G网络参数时,D网络参数固定
第二行为D网络的损失,更新D网络参数时,G网络参数固定
注意公式中,真实图像输入x,是一个数据集中随机的图片,若图片是BMP图,二维的图,x是一个二维的分布。而公式中对D(x)的输出是一维的分布。但落到具体某一张图输入时,D(x)输出是一个确定数(数值大表示真实图置信度高,反之置信度小,最大为1)。而公式中对一个D(x)输出的分布求对数再求期望,是信息论计算熵的基本操作。(若logD(x)表示信息量,熵就是平均信息量)。而这里可以理解为平均置信度。
公式的代码转换
计算G网络损失
希望D网络判断G网络的输出图片为真。
其中,valid是维度为[batch,1]的全1矩阵,损失为二元交叉熵。
损失越小,生成图片经过判别器输出越接近全1矩阵。
计算D网络损失
希望D网络判断真实图片为真,生成图片为假。损失代码如下,
注意下detach操作,这里gen_imgs=D(z),由于G网络包含梯度信息,detach操作让计算图计算梯度信息时,不考虑G网络参数。正如上面所述更新G或D网络参数时,是不要更新另一个网络参数信息的,当成常数就好。detach操作可以减少对G网络梯度信息计算,减少计算量。
代码测试
https://github.com/eriklindernoren/PyTorch-GAN.git
获取git下的implementations文件下的网络有很多,其中大多都能直接运行,测试。
总结
GAN的原理似乎非常的简单,真正天才般、颠覆性的想法往往大道至简,简到让大家都觉得自己能够想出来。然而灵感的产生可能就是平时思考的一念之间的差距,却是难以跨越的鸿沟。