变量必须先初始化后才可使用。如果您在低级别 TensorFlow API 中进行编程(即您在显式创建自己的图和会话),则必须明确初始化变量。tf.contrib.slim
、tf.estimator.Estimator
和 Keras
等大多数高级框架在训练模型前会自动为您初始化变量。
显式初始化在其他方面很有用。它允许您在从检查点重新加载模型时不用重新运行潜在资源消耗大的初始化器,并允许在分布式设置中共享随机初始化的变量时具有确定性。
要在训练开始前一次性初始化所有可训练变量,请调用 tf.global_variables_initializer()
。此函数会返回一个操作,负责初始化 tf.GraphKeys.GLOBAL_VARIABLES
集合中的所有变量。运行此操作会初始化所有变量。例如:
session.run(tf.global_variables_initializer())
# Now all variables are initialized.
如果您确实需要自行初始化变量,则可以运行变量的初始化器操作。例如:
session.run(my_variable.initializer)
您可以查询哪些变量尚未初始化。例如,以下代码会打印所有尚未初始化的变量名称:
print(session.run(tf.report_uninitialized_variables()))
请注意,默认情况下,tf.global_variables_initializer
不会指定变量的初始化顺序。因此,如果变量的初始值取决于另一变量的值,那么很有可能会出现错误。任何时候,如果您在并非所有变量都已初始化的上下文中使用某个变量值(例如在初始化某个变量时使用另一变量的值),最好使用 variable.initialized_value()
,而非 variable
:
v = tf.get_variable("v", shape=(), initializer=tf.zeros_initializer())
w = tf.get_variable("w", initializer=v.initialized_value() + 1)