InfoGAN详解与实现(采用tensorflow2(1)

本文详细介绍了InfoGAN(信息最大生成对抗网络)的工作原理,包括如何通过辅助分布Q(c|x)估计互信息的下限,以及生成器和鉴别器的损失函数。InfoGAN在MNIST数据集上的应用能够学习到离散和连续编码,以控制生成器的输出属性。文章还提供了基于TensorFlow2的InfoGAN模型构建和训练的代码实现。
摘要由CSDN通过智能技术生成

但是由于估计 H ( c ∣ G ( z , c ) ) H(c|G(z,c)) H(c∣G(z,c))需要后验分布 p ( c ∣ G ( z , c ) ) = p ( c ∣ x ) p(c|G(z,c))=p(c|x) p(c∣G(z,c))=p(c∣x),因此难以估算 H ( c ∣ G ( z , c ) ) H(c|G(z,c)) H(c∣G(z,c))。

解决方法是通过使用辅助分布 Q ( c ∣ x ) Q(c|x) Q(c∣x)估计后验概率来估计互信息的下限,估计相互信息的下限为:

I ( c ; G ( z , c ) ) ≥ L I ( G , Q ) = E c ∼ p ( c ) , x ∼ G ( z , c ) [ l o g Q ( c ∣ x ) ] + H ( c ) I(c;G(z,c)) \ge L_I(G,Q)=E_{c \sim p©,x \sim G(z,c)}[logQ(c|x)]+H© I(c;G(z,c))≥LI​(G,Q)=Ec∼p©,x∼G(z,c)​[logQ(c∣x)]+H©

在InfoGAN中,假设 H ( c ) H© H©为常数。因此,使互信息最大化是使期望最大化的问题。生成器必须确信已生成具有特定属性的输出。此期望的最大值为零。因此,互信息的下限的最大值为 H ( c ) H© H©。在InfoGAN中,离散潜在编码 Q ( c ∣ x ) Q(c|x) Q(c∣x)的可以用softmax表示。期望是tf.keras中的负categorical_crossentropy损失。

对于一维连续编码,期望是 c c c和 x x x上的二重积分,这是由于期望样本同时来自分离编码分布和生成器分布。估计期望值的一种方法是通过假设样本是连续数据的良好度量。因此,损失估计为 c l o g Q ( c ∣ x ) clogQ(c|x) clogQ(c∣x)。

为了完成InfoGAN的网络,应该有一个 l o g Q ( c ∣ x ) logQ(c|x) logQ(c∣x)的实现。为简单起见,网络Q是附加到鉴别器的辅助网络。

InfoGAN网络架构鉴别器损失函数

L ( D ) = − E x ∼ p d a t a l o g D ( x ) − E z , c l o g [ 1 − D ( G ( z , c ) ) ] − λ I ( c ; G ( z , c ) ) \mathcal L^{(D)} = -\mathbb E_{x\sim p_{data}}logD(x)-\mathbb E_{z,c}log[1 − D(G(z,c))]-\lambda I(c;G(z,c)) L(D)=−Ex∼pdata​​logD(x)−Ez,c​log[1−D(G(z,c))]−λI(c;G(z,c))

生成器损失函数:

L ( G ) = − E z , c l o g D ( G ( z , c ) ) − λ I ( c ; G ( z , c ) ) \mathcal L^{(G)} = -\mathbb E_{z,c}logD(G(z,c))-\lambda I(c;G(z,c)) L(G)=−Ez,c​logD(G(z,c))−λI(c;G(z,c))

其中 λ \lambda λ是正的常数

InfoGAN实现


如果将其应用于MNIST数据集,InfoGAN可以学习分离的离散编码和连续编码,以修改生成器输出属性。 例如,像CGAN和ACGAN一样,将使用10维独热标签形式的离散编码来指定要生成的数字。但是,可以添加两个连续的编码,一个用于控制书写样式的角度,另一个用于调整笔划宽度。保留较小尺寸的编码以表示所有其他属性:

MNIST数据集编码形式

导入必要库

import tensorflow as tf

import numpy as np

from tensorflow import keras

import os

from matplotlib import pyplot as plt

import math

from PIL import Image

from tensorflow.keras import backend as K

生成器

def generator(inputs,image_size,activation=‘sigmoid’,labels=None,codes=None):

“”"generator model

Arguments:

inputs (layer): input layer of generator

image_s

  • 5
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值