Tensorflow查看网络、冻结变量和迁移训练
(Inspect network structure, freeze graph variables, and finetune/transfer learning in Tensorflow)
1. 查看网络结构和参数
python
/usr/local/lib/python2.7/dist-packages/tensorflow/python/tools/inspect_checkpoint.py
--file_name=model.ckpt-1562770
--tensor_name=unit_1_2/sub1/conv1/DW
源码中的inspect_checkpoint.py可以看ckpt文件中的层和某层的权重值
如果只有--file_name就只显示层,如果还有--tensor_name就能显示那一层的权重
2. 只训练graph中部分变量(相当于冻结了其他变量)
Tensorflow在构建graph的过程中会默认自动收集一些变量名到对应的Collection。例如TRAINABLE_VARIABLES就是所有可训练的变量集合。
因此可以通过使用tf.get_collection,指定TRAINABLE_VARIABLES,使其仅包含我们需要重新训练的变量,来冻结其他变量的训练。
例子如下:
first_train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
"unit_last")
trainable_variables = first_train_vars
#print trainable_variables
grads = self.optimizer.compute_gradients(self.cost, self.trainable_variables)
3. 更改graph后恢复训练
根据monitored_session.py,使用MonitoredTrainingSession来开启控制Session的时候,若指定的checkpoint路径中有上次的存档,则现有源码只能严格按照之前训练恢复。因此我们需要一个空的checkpoint路径,此时MonitoredTrainingSession就会执行init_op以及init_fn。在init_fn中自己添加恢复函数,并把init_fn作为参数加入MonitoredTrainingSession中的scaffold即可。
例子如下:
variables_to_restore = tf.contrib.framework.get_variables_to_restore(
exclude=['logit'])
init_assign_op, init_feed_dict = tf.contrib.framework.assign_from_checkpoint(
ckpt.model_checkpoint_path, variables_to_restore)
def InitAssignFn(scaffold, sess):
sess.run(init_assign_op, init_feed_dict)
scaffold = tf.train.Scaffold(saver=tf.train.Saver(), init_fn=InitAssignFn)