TensorFlow Alternating Training
tf.compat.v1.Session
Define a Session:
sess = tf.compat.v1.Session()
tf.compat.v1.get_collection
- Get a list of all trainable variables in the constructed static graph:
variables = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES)
for i in variables:
    print(i)
- Get a list of all trainable variables under a given scope:
train_var = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES, scope='output')
for i in train_var:
    print(i)
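A minimal runnable sketch of the two calls above (assumes TensorFlow 2.x with the `compat.v1` API available; the scope names `first` and `second` are illustrative):

```python
import tensorflow as tf

tf.compat.v1.disable_eager_execution()
tf.compat.v1.reset_default_graph()

# Build two variables under different variable scopes.
with tf.compat.v1.variable_scope('first'):
    w1 = tf.compat.v1.get_variable('w', shape=[2, 2])
with tf.compat.v1.variable_scope('second'):
    w2 = tf.compat.v1.get_variable('w', shape=[2, 2])

# All trainable variables in the graph.
all_vars = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES)
# Only the variables whose names match the 'second' scope.
second_vars = tf.compat.v1.get_collection(
    tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES, scope='second')

print(len(all_vars), len(second_vars))  # 2 1
```

Note that the `scope` argument is matched as a regular expression against variable names, so `scope='second'` picks up everything named `second/...`.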
tf.compat.v1.name_scope
(1) Every object and operation defined inside a region opened by tf.compat.v1.name_scope() has the region's name prepended to its "name" attribute, which identifies which region an object belongs to;
(2) Grouping related objects and operations into tf.compat.v1.name_scope() regions produces a clear logical graph in TensorBoard, which is especially important for complex graphs.
with tf.compat.v1.name_scope("generator"):
    sample = tf.compat.v1.Variable(tf.random.normal([self.n, self.m]), name='sample')
......
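A small self-contained example of point (1), showing the scope name prepended to the "name" attribute (a sketch assuming TensorFlow 2.x with the `compat.v1` API; the shape is arbitrary):

```python
import tensorflow as tf

tf.compat.v1.disable_eager_execution()
tf.compat.v1.reset_default_graph()

with tf.compat.v1.name_scope('generator'):
    # In graph mode, both the variable and the op get the scope prefix.
    sample = tf.compat.v1.Variable(tf.random.normal([3, 4]), name='sample')
    summed = tf.reduce_sum(sample, name='summed')

print(sample.name)  # generator/sample:0
print(summed.name)  # generator/summed:0
```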
tf.compat.v1.variable_scope
TensorFlow provides the variable_scope mechanism for sharing variables.
with tf.compat.v1.variable_scope("discriminate", reuse=tf.compat.v1.AUTO_REUSE):
    wo = tf.compat.v1.get_variable(shape=[dims[i], dims[i + 1]], name='w_discriminate')
......
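The sharing behavior can be seen in a minimal sketch (assumes TensorFlow 2.x with the `compat.v1` API; the `discriminate` function here is an illustrative stand-in, not the document's actual model):

```python
import tensorflow as tf

tf.compat.v1.disable_eager_execution()
tf.compat.v1.reset_default_graph()

def discriminate(x):
    # With AUTO_REUSE, the first call creates 'discriminate/w' and every
    # later call returns the same variable instead of raising an error.
    with tf.compat.v1.variable_scope('discriminate', reuse=tf.compat.v1.AUTO_REUSE):
        w = tf.compat.v1.get_variable('w', shape=[2, 2])
    return tf.matmul(x, w)

x = tf.compat.v1.placeholder(tf.float32, [None, 2])
y1 = discriminate(x)
y2 = discriminate(x)  # reuses the same weight matrix

vars_ = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES)
print(len(vars_))  # 1
```

Both calls build ops against a single shared `discriminate/w`, which is exactly what lets two sub-networks (e.g. in a GAN) share a discriminator.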
tf.compat.v1.train.Optimizer
TensorFlow's Optimizer class.
opt = tf.compat.v1.train.RMSPropOptimizer(lr, decay)
opt.minimize
The training objective, usually minimizing the loss.
global_step = tf.compat.v1.Variable(0, trainable=False, name='global_step')
train_step = opt.minimize(net.loss, global_step=global_step, var_list=train_vars)
var_list is, as the name suggests, the list of variables to be updated during training. It can be obtained via tf.compat.v1.get_collection, and since Python lists support concatenation, separate lists can be combined: list = list1 + list2.
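To see that var_list really restricts which variables a train op touches, here is a minimal sketch (assumes TensorFlow 2.x with the `compat.v1` API; the variables `a` and `b` and the learning rate are illustrative):

```python
import tensorflow as tf

tf.compat.v1.disable_eager_execution()
tf.compat.v1.reset_default_graph()

a = tf.compat.v1.get_variable('a', initializer=1.0)
b = tf.compat.v1.get_variable('b', initializer=1.0)
loss = tf.square(a) + tf.square(b)

opt = tf.compat.v1.train.RMSPropOptimizer(0.1)
# Only 'a' is in var_list, so 'b' stays frozen under this train op.
train_a = opt.minimize(loss, var_list=[a])

with tf.compat.v1.Session() as sess:
    sess.run(tf.compat.v1.global_variables_initializer())
    sess.run(train_a)
    a_val, b_val = sess.run([a, b])

print(a_val != 1.0, b_val == 1.0)  # True True
```

This is the mechanism behind alternating training: each minimize call gets its own var_list, so each step only moves its own subset of parameters.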
tf.compat.v1.global_variables_initializer
Global variable initialization; TensorFlow can only train variables after they have been initialized.
sess.run(tf.compat.v1.global_variables_initializer())
sess.run(train_step, feed_dict={…})
Runs one gradient-update step.
sess.run(train_step, feed_dict={...})
Example code
''' Session '''
sess = tf.compat.v1.Session()
''' Define Net '''
net = Net()
''' Set up var_list '''
first_vars = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES, scope='first')
second_vars = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES, scope='second')
all_vars = first_vars + second_vars
''' Set up optimizer '''
opt = tf.compat.v1.train.RMSPropOptimizer(lr, decay)
''' Minimize '''
global_first = tf.compat.v1.Variable(0, trainable=False, name='global_first')
train_first = opt.minimize(net.loss, global_step=global_first, var_list=first_vars)
global_second = tf.compat.v1.Variable(0, trainable=False, name='global_second')
train_second = opt.minimize(net.loss, global_step=global_second, var_list=second_vars)
global_all = tf.compat.v1.Variable(0, trainable=False, name='global_all')
train_all = opt.minimize(net.loss, global_step=global_all, var_list=all_vars)
''' Initialize TensorFlow variables '''
sess.run(tf.compat.v1.global_variables_initializer())
''' Optimize '''
sess.run(train_first, feed_dict={...})
sess.run(train_second, feed_dict={...})
sess.run(train_all, feed_dict={...})
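The recipe above can be exercised end to end with a runnable toy version (a sketch: the Net class, lr, decay, and the 'first'/'second' scopes from the text are replaced with minimal stand-ins, here a two-layer linear model fit on random data):

```python
import numpy as np
import tensorflow as tf

tf.compat.v1.disable_eager_execution()
tf.compat.v1.reset_default_graph()

x = tf.compat.v1.placeholder(tf.float32, [None, 4])
y = tf.compat.v1.placeholder(tf.float32, [None, 1])

# Two parameter groups under separate scopes, trained alternately.
with tf.compat.v1.variable_scope('first'):
    w1 = tf.compat.v1.get_variable('w', shape=[4, 8])
with tf.compat.v1.variable_scope('second'):
    w2 = tf.compat.v1.get_variable('w', shape=[8, 1])

loss = tf.reduce_mean(tf.square(tf.matmul(tf.matmul(x, w1), w2) - y))

first_vars = tf.compat.v1.get_collection(
    tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES, scope='first')
second_vars = tf.compat.v1.get_collection(
    tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES, scope='second')

opt = tf.compat.v1.train.RMSPropOptimizer(0.01)
train_first = opt.minimize(loss, var_list=first_vars)
train_second = opt.minimize(loss, var_list=second_vars)

sess = tf.compat.v1.Session()
sess.run(tf.compat.v1.global_variables_initializer())

rng = np.random.default_rng(0)
xs = rng.normal(size=(32, 4)).astype(np.float32)
ys = rng.normal(size=(32, 1)).astype(np.float32)

before = sess.run(loss, feed_dict={x: xs, y: ys})
for _ in range(20):
    sess.run(train_first, feed_dict={x: xs, y: ys})   # update 'first' only
    sess.run(train_second, feed_dict={x: xs, y: ys})  # update 'second' only
after = sess.run(loss, feed_dict={x: xs, y: ys})
print(after < before)
```

Each iteration takes one step on the `first` variables with `second` frozen, then one step on `second` with `first` frozen, which is the alternating pattern the example code sketches.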