本文转自:https://www.zhihu.com/question/63493495/answer/212218657
其实 GAN 网络的巧妙在于其设计思维,而技术上是对现有算法的组合,没什么神秘的。既然题主对其大致意思已经了解了,那我就举一个构建实例。GAN 网络主要由两个网络合成。
生成网络:
输入为随机数,输出为生成数据。比如说,输入一个一维随机数,输出一张 28x28 (784 维) 的图片 (MNIST)。
网络实现用最 vanilla 的多层神经网络即可。记得不宜超过三层,否则梯度消失梯度爆炸的问题你懂的。中间的激活函数用当下最时兴的 Relu 就好。输出层需要使用其他激活函数,目的是为了生成数据的取值范围与真实数据相似,具体使用什么函数视情况而定。下面给出一个可能的实现方案:
(?, 1) (1, 128)
relu
(128, 128)
relu
(128, 784)
tanh
判别网络:
现在,我们把生成网络生成的数据称为假数据,对应的,来自真实数据集的数据称为真数据。判别网络输入为数据(真或假),输出一个判别概率。需注意的是,这里判别的是图像的真伪,而非图像的类别。还以 MNIST 为例。输入一个图片后,我们并不要认图片上画的是啥数字,而是判别图像到底来自于真实数据集,还是生成网络的胡乱合成。所以输出一个一维条件概率(伯努利分布的概率参数)就好了。
网络实现同样可用最基本的多层神经网络。下面给出一个可能的方案。
(?, 784) (784, 128)
relu
(128, 128)
relu
(128, 1)
sigmoid
loss 函数:
既然有俩网络,那么我们就有俩 loss 函数对应之。生成网络用
代表 Generative;
代表 Discriminative;
代表交叉熵,这也是常用算法之一,如果题主对于其意义有何不解,网上有大把资料。
是输入生成网络的随机数,那么
就是生成网络合成的假数据,
则是对这个假数据的判别概率。这个 loss 用大白话来说,我生成网络的目标就是要你判别网络觉得我合成的数据是真的!(概率
)
判别网络的 loss 函数用
为真实数据。这个 loss 说的是,我判别网络就是要将真数据归为真,假数据归为假,既不想放过一个假数据,也不想错杀一个真数据。
可见,这两个 loss 的定义非常直觉化。对抗这个称呼就是这么来的。
训练:
训练我们用两步走,先优化一次 再优化一次
,如此往复直到题主满意。两步走的训练算法与 Goodfellow 最初论文中的算法不太一样,不过结果是基本“等价”的。
超参:
GAN 网络对超参的敏感是众所周知的。上面提供的超参绝对不能保证能生成令人满意的结果。我只是拍脑袋想的。。。但是,题主绝对可以得到一些启发性的结果让自己对 GAN 网络有进一步的了解。
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------
关于SRGAN的的论文中文翻译网上一大堆,可以直接读网络模型(大概了解),关于loss的理解,然后就能跑代码
loss = mse + 对抗损失 + 感知损失 : https://blog.csdn.net/DuinoDu/article/details/78819344
具体Pytorch代码可参考文档:http://www.mamicode.com/info-detail-2737094.html