1 引言
没想到第二天就磕到了模型,肝起来吧,前面导入的包就不赘述了,主要是记录自己学习这些代码的过程和学习到的知识点,下面开始:
2 init()中的内容前一部分
part 1 构建计算图
# Build the computational graph
print('Building computational graph ...')
self.G_global_step = tf.Variable(0, trainable=False)
self.D_global_step = tf.Variable(0, trainable=False)
self.handle = tf.placeholder(tf.string, shape=[])
self.training_phase = tf.placeholder(tf.bool)
global_step在滑动平均、优化器、指数衰减学习率等方面都有用到,这个变量的实际意义非常好理解:代表全局步数,比如在多少步该进行什么操作,现在神经网络训练到多少轮等等,类似于一个钟表。
原文地址:TensorFlow中global_step的简单分析
tf.Variable()
中,第一项是G_global_step/D_global_step的初始值为0,trainable=False可以防止该变量被数据流图的收集, 这样我们就不会在训练的时候尝试更新它的值。
接着定义了一个handle(暂时还不清楚它的作用)
placeholder()函数是在神经网络构建graph的时候在模型中的占位,此时并没有把要输入的数据传入模型,它只会分配必要的内存。等建立session,在会话中,运行模型的时候通过feed_dict()函数向占位符喂入数据。
原文地址:tf.placeholder函数说明
本项目中tf.string表示handle的数据类型,以下是官方对tf.string的定义
tf.string:Variable length byte arrays. Each element of a Tensor is a byte array。
shape[]在这里没有限制维数,默认是None,就是一维值,也可以是多维(比如[2,3], [None, 3]表示列是3,行不定),这里应该是行列都不定。
最后定义了一个train_process(暂时还不清楚它的作用,可能是表示训练的2种状态用gan和cgan)使用的是布尔类型
part 2 数据处理
# >>> Data handling
self.path_placeholder = tf.placeholder(paths.dtype, paths.shape)
self.test_path_placeholder = tf.placeholder(paths.dtype)
self.semantic_map_path_placeholder = tf.placeholder(paths.dtype, paths.shape)
self.test_semantic_map_path_placeholder = tf.placeholder(paths.dtype)
train_dataset = Data.load_dataset(self.path_placeholder,
config.batch_size,
augment=False,
training_dataset=dataset,
use_conditional_GAN=config.use_conditional_GAN,
semantic_map_paths=self.semantic_map_path_placeholder)
test_dataset = Data.load_dataset(self.test_path_placeholder,
config.batch_size,
augment=False,
training_dataset=dataset,
use_conditional_GAN=config.use_conditional_GAN,