GAN学习 ACGAN TensorFlow代码解读_acgan代码

收集整理了一份《2024年最新物联网嵌入式全套学习资料》,初衷也很简单,就是希望能够帮助到想自学提升的朋友。
img
img

如果你需要这些资料,可以戳这里获取

需要这些体系化资料的朋友,可以加我V获取:vip1024c (备注嵌入式)

一个人可以走的很快,但一群人才能走的更远!不论你是正从事IT行业的老鸟或是对IT行业感兴趣的新人

都欢迎加入我们的的圈子(技术交流、学习资源、职场吐槽、大厂内推、面试辅导),让我们一起学习成长!

2.可借鉴的代码块写法


本文是针对MNIST手写数据的 ac_gan_tensorflow.py的代码解读,全文按py代码顺序依次解读,对于理解acgan的基本原理有很大的帮助,可以直接运行,但实际操作应该配合其他的网络架构或者improve technique。

代码解读

line1-6  import需要的库

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np  # 定义矩阵
import matplotlib.pyplot as plt  # 画图
import matplotlib.gridspec as gridspec  # 定义网格
import os

line 9-18 load MNIST数据 & 定义关键参数

mnist = input_data.read_data_sets('../MNIST', one_hot=True)

mb_size = 32
X_dim = mnist.train.images.shape[1]  # 图像维度 [数据量,图像维度]
y_dim = mnist.train.labels.shape[1]  # Label 数据长
z_dim = 10  # 噪音维度
h_dim = 128  # 中间层神经元数
eps = 1e-8  # 定义一个很小的数 +eps 可保证不为0
lr = 1e-3  # 学习率
d_steps = 3  # 没用到

其中,读取MNIST数据时,需要将二进制的MNIST数据download完成放在"MNIST"文件夹中

line 21- 34 定义plot 函数 (可作为代码块之后自己用)

def plot(samples):
    fig = plt.figure(figsize=(4, 4))
    gs = gridspec.GridSpec(4, 4)
    gs.update(wspace=0.05, hspace=0.05)

    for i, sample in enumerate(samples):
        ax = plt.subplot(gs[i])
        plt.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        plt.imshow(sample.reshape(28, 28), cmap='Greys_r')  # cmp:

    return fig

line 37-40 定义xavier 初始化函数

def xavier_init(size):
    in_dim = size[0]
    xavier_stddev = 1. / tf.sqrt(in_dim / 2.)
    return tf.random_normal(shape=size, stddev=xavier_stddev)

line 43-45 定义占位符

X = tf.placeholder(tf.float32, shape=[None, X_dim])
y = tf.placeholder(tf.float32, shape=[None, y_dim])
z = tf.placeholder(tf.float32, shape=[None, z_dim])

line 47-66 定义D & G

G_W1 = tf.Variable(xavier_init([z_dim + y_dim, h_dim]))
G_b1 = tf.Variable(tf.zeros(shape=[h_dim]))
G_W2 = tf.Variable(xavier_init([h_dim, X_dim]))
G_b2 = tf.Variable(tf.zeros(shape=[X_dim]))


def generator(z, c):
    inputs = tf.concat(axis=1, values=[z, c])  # ACGAN 把噪音和标签concat作为G的输入
    G_h1 = tf.nn.relu(tf.matmul(inputs, G_W1) + G_b1)
    G_log_prob = tf.matmul(G_h1, G_W2) + G_b2  # logit值
    G_prob = tf.nn.sigmoid(G_log_prob)  # 返回图像G_prob
    return G_prob


D_W1 = tf.Variable(xavier_init([X_dim, h_dim]))
D_b1 = tf.Variable(tf.zeros(shape=[h_dim]))
D_W2_gan = tf.Variable(xavier_init([h_dim, 1]))  # 判断0/1的GAN的Weight
D_b2_gan = tf.Variable(tf.zeros(shape=[1]))  # 判断0/1的GAN的Bias
D_W2_aux = tf.Variable(xavier_init([h_dim, y_dim]))  # 判断类别 Weight
D_b2_aux = tf.Variable(tf.zeros(shape=[y_dim]))  # 判断类别 Bias


def discriminator(X):
    D_h1 = tf.nn.relu(tf.matmul(X, D_W1) + D_b1)
    out_gan = tf.nn.sigmoid(tf.matmul(D_h1, D_W2_gan) + D_b2_gan)  # 输出fake/real 的概率值
    out_aux = tf.matmul(D_h1, D_W2_aux) + D_b2_aux  # 输出class 的概率值
    return out_gan, out_aux


theta_G = [G_W1, G_W2, G_b1, G_b2]
theta_D = [D_W1, D_W2_gan, D_W2_aux, D_b1, D_b2_gan, D_b2_aux]

line 80-85 定义随机噪音函数 & 交叉熵函数

def sample_z(m, n):  # 产生随机噪音
    return np.random.uniform(-1., 1., size=[m, n])


def cross_entropy(logit, y):  #定义交叉熵
    return -tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logit, labels=y))

line 88-85 定义随机噪音函数 & 交叉熵函数

G_sample = generator(z, y)  # G产生的fake image

D_real, C_real = discriminator(X) 
D_fake, C_fake = discriminator(G_sample)

# Cross entropy aux loss  # 标签误差
C_loss = cross_entropy(C_real, y) + cross_entropy(C_fake, y)

# GAN D loss
D_loss = tf.reduce_mean(tf.log(D_real + eps) + tf.log(1. - D_fake + eps))  # +eps 加上一个很小的数 防止为0
DC_loss = -(D_loss + C_loss)

# GAN's G loss
G_loss = tf.reduce_mean(tf.log(D_fake + eps))
GC_loss = -(G_loss + C_loss)

D_solver = (tf.train.AdamOptimizer(learning_rate=lr)
            .minimize(DC_loss, var_list=theta_D))
G_solver = (tf.train.AdamOptimizer(learning_rate=lr)
            .minimize(GC_loss, var_list=theta_G))

line 88-108 (ACGAN的关键部分) 定义D&G的loss [DC_loss, GC_loss]

用Adam优化器优化G\D参数

G_sample = generator(z, y)  # G产生的fake image

D_real, C_real = discriminator(X)
D_fake, C_fake = discriminator(G_sample)

# Cross entropy aux loss  # 标签误差
C_loss = cross_entropy(C_real, y) + cross_entropy(C_fake, y)

# GAN D loss
D_loss = tf.reduce_mean(tf.log(D_real + eps) + tf.log(1. - D_fake + eps))  # +eps 加上一个很小的数 防止为0
DC_loss = -(D_loss + C_loss)

# GAN's G loss
G_loss = tf.reduce_mean(tf.log(D_fake + eps))
GC_loss = -(G_loss + C_loss)

D_solver = (tf.train.AdamOptimizer(learning_rate=lr)
            .minimize(DC_loss, var_list=theta_D))
G_solver = (tf.train.AdamOptimizer(learning_rate=lr)
            .minimize(GC_loss, var_list=theta_G))

line118-146 py主体运行部分

or it in range(1000000):
    X_mb, y_mb = mnist.train.next_batch(mb_size)
    z_mb = sample_z(mb_size, z_dim)  # 产生batch的噪音z

    _, DC_loss_curr = sess.run(
        [D_solver, DC_loss],
        feed_dict={X: X_mb, y: y_mb, z: z_mb}
    )

    _, GC_loss_curr = sess.run(
        [G_solver, GC_loss],
        feed_dict={X: X_mb, y: y_mb, z: z_mb}
    )



![img](https://img-blog.csdnimg.cn/img_convert/1e56963ef2023578c0deb0c2bb1a6d17.png)
![img](https://img-blog.csdnimg.cn/img_convert/b3ed7a3d8838aea4e762202561e41b1e.png)

**既有适合小白学习的零基础资料,也有适合3年以上经验的小伙伴深入学习提升的进阶课程,涵盖了95%以上物联网嵌入式知识点,真正体系化!**

**由于文件比较多,这里只是将部分目录截图出来,全套包含大厂面经、学习笔记、源码讲义、实战项目、大纲路线、电子书籍、讲解视频,并且后续会持续更新**

**需要这些体系化资料的朋友,可以加我V获取:vip1024c (备注嵌入式)**

**[如果你需要这些资料,可以戳这里获取](https://bbs.csdn.net/topics/618679757)**

料,也有适合3年以上经验的小伙伴深入学习提升的进阶课程,涵盖了95%以上物联网嵌入式知识点,真正体系化!**

**由于文件比较多,这里只是将部分目录截图出来,全套包含大厂面经、学习笔记、源码讲义、实战项目、大纲路线、电子书籍、讲解视频,并且后续会持续更新**

**需要这些体系化资料的朋友,可以加我V获取:vip1024c (备注嵌入式)**

**[如果你需要这些资料,可以戳这里获取](https://bbs.csdn.net/topics/618679757)**

  • 4
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值