tensorflow version: v 1.13
在tf.estimator.WarmStartSettings
里, 是可以设置具体restore
哪些变量, 里面有两个参数,
-
ckpt_to_initialize_from
, 就是需要restore
的ckpt地址, 例如xx/model.ckpt-xx
-
vars_to_warm_start
, 就是判断哪些变量需要被restore的. 需要说明的是, 默认值是".*"
,它会使得程序去resotre
所有的TRAINABLE_VARIABLES
变量, 而这个变量集合里是不含有bn层的moving_mean
及moving_variance
参数的, 按照官网的解释就是:Defaults to '.*', which warm-starts all variables in the TRAINABLE_VARIABLES collection. Note that this excludes variables such as accumulators and moving statistics from batch norm.
因此有两种解决方法, 1是根据你需要
restore
的变量scope名字, 显式的指定需要restore
的变量, 如有多个名字, 可以放在一个列表里, 变量名字支持正则表达式寻找. 2 是传入参数[".*"]
, 让程序去从Global variable
里去恢复所有的变量, 这样就可以把bn层的moving_mean
及moving_variance
参数成功restore
.