用 TensorFlow 比较多的同学,会发现 reuse variable 来建立模型 (graph) 有时候是必须的,比如建立RNN模型时,num_steps 在 training 和 testing 的时候往往是不同的。
以 dropout 为例: 在 testing 的时候,应该是关闭的。而在 training的时候是启用的。
TensorLayer 有2个简单的方法解决实现这个
方法1 这下面连接的例子中,Layer 内部通过 placeholder 来设置dropout keeping probabilities,当testing的时候,probabilities被设为 1
TensorLayer simple example : http://tensorlayercn.readthedocs.io/zh/latest/user/tutorial.html#tensorlayer
# 训练网络
tl.utils.fit(sess, network, train_op, cost, X_train, y_train, x, y_,
acc=acc, batch_size=500, n_epoch=500, print_freq=5,
X_val=X_val, y_val=y_val, eval_train=False)
# 测试网络
tl.utils.test(sess, network, acc, X_test, y_test, x, y_, batch_size=None, cost=cost)
方法2 上面的代码使用了TensorLayer 提供的傻瓜式函数,和keras、tflearn差不多。但 TensorLayer 作者鼓励大家使用 TensorFlow 的原生方法。
TensorLayer MNIST examples : https://github.com/zsdonghao/tensorlayer/blob/master/tutorial_mnist.py
# 训练时启动dropout
feed_dict = {x: X_train_a, y_: y_train_a}
feed_dict.update( network.all_drop ) # enable dropout or dropconnect layers
sess.run(train_op, feed_dict=feed_dict)
# 测试时关闭dropout,把probabilities 全设为1,放入feed_dict
dp_dict = tl.utils.dict_to_one( network.all_drop )
feed_dict = {x: X_val, y_: y_val}
feed_dict.update(dp_dict)
比如重复使用 variable 的情况:RNN为例,除了dropout 关闭外,还需要使用不同的 num_steps,那就一定要建立不同的 computation graph 了。
graph reuse variables。可以这样实现:
TensorLayer PTB tutorial: https://github.com/zsdonghao/tensorlayer/blob/master/tutorial_ptb_lstm.py
这个代码中,最关键的代码是:
def inference(x, is_training, num_steps, reuse=None):
with tf.variable_scope("model", reuse=reuse): # reuse=True时,则让TensorFlow 知道 reuse variable
tl.layers.set_name_reuse(reuse) # reuse = True时,TensorLayer allows reuse the same Layer name
network = tl.layers.EmbeddingInputlayer(.....
.....
......
这样,你就可以如下建立多个 graph with the same variables了。
# Inference for Training
network, lstm1, lstm2 = inference(input_data, is_training=True, num_steps=num_steps, reuse=None)
# Inference for Validating
network_val, lstm1_val, lstm2_val = inference(input_data, is_training=False, num_steps=num_steps, reuse=True)
# Inference for Testing (Evaluation)
network_test, lstm1_test, lstm2_test = inference(input_data_test, is_training=False, num_steps=1, reuse=True)