共享变量:reuse_variables()
example1:
with tf.variable_scope("try"):
#先创建两个变量w1, w2
w2 = tf.get_variable("w1",shape=[2,3,4], dtype=tf.float32)
w3 = tf.get_variable("w2", shape=[2, 3, 4], dtype=tf.float32)
#使用reuse_variables 将刚刚创建的两个变量共享
tf.get_variable_scope().reuse_variables()
w4 = tf.get_variable("w1", shape=[2, 3, 4], dtype=tf.float32)
w5 = tf.get_variable("w2", shape=[2, 3, 4], dtype=tf.float32)
#再进行共享的话,还需要再使用一次reuse_variables()
tf.get_variable_scope().reuse_variables()
w6 = tf.get_variable("w1", shape=[2, 3, 4], dtype=tf.float32)
w7 = tf.get_variable("w2", shape=[2, 3, 4], dtype=tf.float32)
example2:
with tf.variable_scope("RNN"):
for time_step in range(num_steps):
if time_step > 0: tf.get_variable_scope().reuse_variables()
# cell_out: [batch, hidden_size]
(cell_output, state) = cell(inputs[:, time_step, :], state) # 按照顺序向cell输入文本数据
outputs.append(cell_output) # output: shape[num_steps][batch,hidden_size]
这样就可以实现变量共享了