var = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="resnet_50/block1")
saver = tf.train.Saver(var_list=var)
saver.restore(sess, "resnet_50.ckpt")
以上代码段为resnet_50的网络只加载name_scope为(resnet_50/block1)内的变量,即指定某一name_scope内的变量进行加载。
var = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="resnet_50/block1")
saver = tf.train.Saver(var_list=var)
saver.restore(sess, "resnet_50.ckpt")
以上代码段为resnet_50的网络只加载name_scope为(resnet_50/block1)内的变量,即指定某一name_scope内的变量进行加载。