GAN神经网络的keras实现

本文介绍了GAN的原理和Keras实现,包括Generator和Discriminator的相互作用,重点讨论了K值和输入噪音维度对实验结果的影响,展示了一个简单的GAN模型结构和训练过程。
摘要由CSDN通过智能技术生成

具体原理可以参考上面的文献,不过在这里还是大概讲一下。
其实GAN的原理非常简单,它有2个子网络组成,一个是Generator,即生成网络,它以噪音样本为输入,通过学习到的权重,把噪音转变(即生成)为有意义的信号;一个是Discriminator,即判别网络,他以信号为输入(可以来自generator生成的信号,也可以是真正的信号),通过学习来判别信号的真假,并输出一个0-1之间的概率。可以把Generator比喻为一个假的印钞机,而Discriminator则是验钞机,他们两个互相竞争,使得印钞机越来越真,同时验钞机也越来越准。但是最终我们是希望Generator越来越真,而Discriminator的输出都是0.5,即难以分辨~~

而在训练的时候,则分两个阶段进行,第一个阶段是Discriminator的学习,此时固定Generator的权重不变,只更新Discriminator的权重。loss函数是:

 

1mi=1m[logD(xi)+log(1D(G(zi)))]

 

其中m是batch_size, x表示真正的信号,z表示噪音样本。训练时分别从噪音分布和真实分布中选出m个噪音输入样本和m个真实信号样本,通过对以上的loss function最大化更新Discriminator的权重

第二个阶段是对Generator进行训练,此时的loss function是:

 

1mi=1m[log(1D(G(zi)))]

 

不过,此时是对loss最小化来更新Generator的权重。

另外,这2个阶段并不是交替进行的,而是执行K次Discriminator的更新,再执行1次Generator的更新。
后面的实验结果也显示,K的选择非常关键。

具体实现

主要工具是 python + keras,用keras实现一些常用的网络特别容易,比如MLP、word2vec、LeNet、lstm等等,github上都有详细demo。但是稍微复杂些的就要费些时间自己写了。不过整体看,依然比用原生tf写要方便。而且,我们还可以把keras当初是学习tf的参考代码,里面很多写法都非常值得借鉴。

废话不多说了,直接上代码吧:

GANmodel

只列出最主要的代码


# 这是针对GAN特殊设计的loss function
def log_loss_discriminator(y_true, y_pred): return - K.log(K.maximum(K.epsilon(), y_pred)) def log_loss_generator(y_true, y_pred): return K.log(K.maximum(K.epsilon(), 1. - y_pred)) class GANModel: def __init__(self, input_dim, log_dir = None): ''' __tensor[0]: 定义了discriminateor的表达式, 对y进行判别,true samples __tensor[1]: 定义了generator的表达式, 对x进行生成,noise samples ''' if isinstance(input_dim, list): input_dim_y, input_dim_x = input_dim[0], input_dim[1] elif isinstance(input_dim, int): input_dim_x = input_dim_y = input_dim else: raise ValueError("input_dim should be list or interger, got %r" % input_dim) # 必须使用名字,方便后面分别输入2个信号 self.__inputs = www.97yingyuan.org [layers.Input(shape=(input_dim_y,), name = "y"),<
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值