在 /home/your_name/TensorFlow/DCGAN/ 下新建文件 train.py,同时新建文件夹 logs 和文件夹 samples,前者用来保存训练过程中的日志和模型,后者用来保存训练过程中采样器的采样图片,在 train.py 中输入如下代码:
# -*- coding: utf-8 -*- import tensorflow as tf import os from read_data import * from utils import * from ops import * from model import * from model import BATCH_SIZE def train(): # 设置 global_step ,用来记录训练过程中的 step global_step = tf.Variable(0, name = 'global_step', trainable = False) # 训练过程中的日志保存文件 train_dir = '/home/your_name/TensorFlow/DCGAN/logs' # 放置三个 placeholder,y 表示约束条件,images 表示送入判别器的图片, # z 表示随机噪声 y= tf.placeholder(tf.float32, [BATCH_SIZE, 10], name='y') images = tf.placeholder(tf.float32, [64, 28, 28, 1], name='real_images') z = tf.placeholder(tf.float32, [None, 100], name='z') # 由生成器生成图像 G G = generator(z, y) # 真实图像送入判别器 D, D_logits = discriminator(images, y) # 采样器采样图像 samples = sampler(z, y) # 生成图像送入判别器 D_, D_logits_ = discriminator(G, y, reuse = True) # 损失计算 d_loss_real = tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits(D_logits, tf.ones_like(D))) d_loss_fake = tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits(D_logits_, tf.zeros_like(D_))) d_loss = d_loss_real + d_loss_fake g_loss = tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits(D_logits_, tf.ones_like(D_))) # 总结操作 z_sum = tf.histogram_summary("z", z) d_sum = tf.histogram_summary("