第一段代码：model_CT.py
用于构建生成器（G）和判别器（D）
"""
Created on Tue Jul 24 20:33:14 2018
E-mail: Eric2014_Lv@sjtu.edu.cn
@author: DidiLv
"""
import tensorflow as tf
import numpy as np
def conv2d(x, W):
    """Convolve ``x`` with kernel ``W`` using unit stride and 'SAME' padding.

    Args:
        x: Input tensor handed to ``tf.nn.conv2d`` with its default
            data format (NHWC). Assumes x is (batch, height, width,
            channels); TODO confirm at call sites.
        W: Filter tensor laid out as
            [filter_height, filter_width, in_channels, out_channels].

    Returns:
        The convolution output. With stride 1 and 'SAME' padding, its
        spatial size matches that of ``x``.
    """
    # Unit stride along every axis of the 4-D input.
    unit_strides = [1] * 4
    return tf.nn.conv2d(x, W, unit_strides, 'SAME')
def avg_pool_2x2(x):
    """Average-pool ``x`` over 2x2 spatial windows with stride 2, 'SAME' padding.

    Args:
        x: 4-D input tensor handed to ``tf.nn.avg_pool`` with its default
            data format (NHWC). Assumes x is (batch, height, width,
            channels); TODO confirm at call sites.

    Returns:
        The pooled tensor. Each spatial dimension becomes ceil(dim / 2);
        the batch and channel axes are left unchanged (window/stride 1).
    """
    # Window and stride are identical: non-overlapping 2x2 tiles.
    two_by_two = [1, 2, 2, 1]
    return tf.nn.avg_pool(x, two_by_two, two_by_two, 'SAME')
def xavier_init(size):
    """Return a normally distributed random tensor of shape ``size``.

    The standard deviation is ``1 / sqrt(size[0] / 2)``, which equals
    ``sqrt(2 / size[0])``, where ``size[0]`` is treated as the fan-in.

    NOTE(review): despite the function name, this is the He-style scale.
    It is not Glorot/Xavier's ``sqrt(2 / (fan_in + fan_out))``. Rename or
    change only after checking what callers expect.

    Args:
        size: Shape of the tensor to draw (list/tuple); its first entry is
            used as the fan-in.

    Returns:
        A ``tf.random_normal`` tensor. Presumably used as the initial value
        of a weight variable; verify against callers.
    """
    fan_in = size[0]
    return tf.random_normal(shape=size, stddev=1. / tf.sqrt(fan_in / 2.))
def sample_z(shape):
    """Draw noise samples uniformly from the half-open interval [-1, 1).

    Args:
        shape: Output shape, forwarded as ``size`` to
            ``np.random.uniform``.

    Returns:
        A NumPy array of the requested shape. Draws from NumPy's global
        random state, so results follow ``np.random.seed``.
    """
    low, high = -1.0, 1.0
    return np.random.uniform(low=low, high=high, size=shape)
def discriminator(x_image, reuse=False):
with tf.variable_scope('discriminator') as scope:
if (reuse):
tf.get_variable_scope().reuse_variables()
W_conv1 = tf.get_variable('d_wconv1', shape = [5, 5, 1, 8], initializer=tf.truncated_normal_initializer(stddev=0.02))
b_conv1 = tf.get_variable('d_bconv1', shape = [8], initializer=tf.constant_initializer(0))
h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)
h_pool1 = avg_pool_2x2(h_conv1)
W_conv2 = tf.get_variable('d_wconv2', shape = [5, 5, 8, 16], initializer=tf.truncated_normal_initializer(stddev=0.02))
b_conv2 = tf.get_variable('d_bconv2', shape = [16], initializer=tf.constant_initializer(0))
h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
h_pool2 = avg_pool_2x2(h_conv2)
W_conv3 = tf.get_variable('d_wconv3', shape = [5, 5, 16, 32], initializer=tf.truncated_normal_initializer(stddev=0.02))
b_conv3 = tf.get_variable('d_bconv3', shape = [32], initializer=tf.constant_initializer(0))
h_conv3 = tf.nn.relu(conv2d(h_pool2, W_conv3) + b_conv3)
h_pool3 = avg_pool_2x2(h_conv3)
W_conv4 = tf.get_variable('d_wconv4', shape = [5, 5, 32, 64], initializer=tf.truncated_normal_initializer(stddev=0.02))
b_conv4 = tf.get_variable('d_bconv4', shape = [64], initializer=tf.constant_initializer(0))
h_conv4 = tf.nn.relu(conv2d(h_pool3, W_conv4) + b_conv4)
h_pool4 = avg_pool_2x2(h_conv4)
W_fc1 = tf.get_variable('d_wfc1', [14 * 12 * 64, 320], initializer=tf.truncated_normal_initializer(stddev=0.02))
b_fc1 = tf.get_variable('d_bfc1', [320], initializer=tf.constant_initializer(0))
h_pool4_flat = tf.reshape(h_pool4, [-1, 14 * 12