由于 TensorFlow 程序的未连接部分可能需要创建变量,因此能有一种方式访问所有变量有时十分受用。为此,TensorFlow 提供了集合,它们是张量或其他对象(如 tf.Variable
实例)的命名列表。
默认情况下,每个 tf.Variable
都放置在以下两个集合中:
tf.GraphKeys.GLOBAL_VARIABLES
- 可以在多台设备间共享的变量,tf.GraphKeys.TRAINABLE_VARIABLES
- TensorFlow 将计算其梯度的变量。
如果您不希望变量可训练,可以将其添加到 tf.GraphKeys.LOCAL_VARIABLES
集合中。例如,以下代码段展示了如何将名为 my_local
的变量添加到此集合中:
my_local = tf.get_variable("my_local", shape=(),
collections=[tf.GraphKeys.LOCAL_VARIABLES])
或者,您可以将 trainable=False
指定为 tf.get_variable
的参数:
my_non_trainable = tf.get_variable("my_non_trainable",
shape=(),
trainable=False)
您也可以使用自己的集合。集合名称可为任何字符串,且您无需显式创建集合。创建变量(或任何其他对象)后,要将其添加到集合中,请调用 tf.add_to_collection
。例如,以下代码将名为 my_local
的现有变量添加到名为 my_collection_name
的集合中:
tf.add_to_collection("my_collection_name", my_local)
要检索您放置在某个集合中的所有变量(或其他对象)的列表,您可以使用:
tf.get_collection("my_collection_name")