import tensorflow as tf

# Use a pretrained VGG-16 as a frozen feature extractor (TensorFlow 1.x graph
# API): import its graph, tap an intermediate conv layer, block gradients at
# that tap with tf.stop_gradient, and build new trainable layers on top.

# Import the VGG-16 meta-graph into the default graph. This rebuilds the ops
# only; the pretrained weights are not loaded until the returned saver is used,
# e.g. vgg_saver.restore(sess, <checkpoint prefix>), inside a session.
# NOTE(review): `dir` is assumed to be a path string defined earlier in the
# original code; as used here it shadows the builtin dir() — confirm, and
# consider renaming it.
vgg_saver = tf.train.import_meta_graph(dir + '/vgg/results/vgg-16.meta')
# Access the graph that now holds the VGG-16 ops.
vgg_graph = tf.get_default_graph()
# Retrieve the VGG input tensor. This snippet was taken from a method body,
# hence the assignment to an attribute of `self`.
self.x_plh = vgg_graph.get_tensor_by_name('input:0')
# Choose the VGG node your own graph connects to. Alternatives, from shallow
# to deep: 'conv1_2:0', 'conv2_2:0', 'conv3_3:0', 'conv4_3:0', 'conv5_3:0'.
output_conv = vgg_graph.get_tensor_by_name('conv1_2:0')
# Stop the gradient here. tf.stop_gradient is the identity in the forward
# pass, but no gradient propagates back through it, so training the layers
# built below does not update the VGG weights via this path: VGG stays frozen
# and acts as a fixed feature extractor rather than being fine-tuned.
output_conv_sg = tf.stop_gradient(output_conv)
# Build further operations: a trainable 1x1 convolution to 32 channels
# followed by ReLU.
# NOTE(review): indexing [3] assumes a rank-4 NHWC tensor (the tf.nn.conv2d
# default layout) with a statically known channel count — confirm.
output_conv_shape = output_conv_sg.get_shape().as_list()
W1 = tf.get_variable('W1', shape=[1, 1, output_conv_shape[3], 32],
                     initializer=tf.random_normal_initializer(stddev=1e-1))
b1 = tf.get_variable('b1', shape=[32], initializer=tf.constant_initializer(0.1))
z1 = tf.nn.conv2d(output_conv_sg, W1, strides=[1, 1, 1, 1], padding='SAME') + b1
a = tf.nn.relu(z1)
TensorFlow 微调模型：如何中断梯度
最新推荐文章于 2022-05-23 13:47:57 发布