# Reset the default graph so the model can be rebuilt from scratch. Without
# this, re-running the script in the same process (e.g. a notebook) raises an
# error -- presumably because NetVLAD's variables already exist in the old
# graph (TODO confirm against lp.NetVLAD).
tf.reset_default_graph()

# `with tf.Session() as sess:` installs `sess` as the default session, which is
# what allows `a.eval()` below to run without passing a session explicitly.
with tf.Session() as sess:
    # Build the NetVLAD pooling layer; argument semantics are defined by
    # lp.NetVLAD.
    # NOTE(review): is_training=True together with add_batch_norm=True puts
    # batch norm in training mode -- confirm that is intended for this check.
    NetVLAD = lp.NetVLAD(feature_size=1024, max_samples=1, cluster_size=5,
                         output_dim=1024, gating=False, add_batch_norm=True,
                         is_training=True)

    # Fetch one batch of 100 training samples. get_train_batches also returns
    # a new batch_index (presumably the start of the next batch -- confirm).
    batch_index = 0
    train_batch, batch_index = get_train_batches(
        dataset_train_index, x_train_dataset, y_train_dataset,
        batch_index, batch_size=100)

    # Build the forward graph BEFORE creating the initializer:
    # tf.global_variables_initializer() only initializes the variables that
    # exist in the graph at the moment it is called.
    a = NetVLAD.forward(train_batch['x_train_batch'])
    sess.run(tf.global_variables_initializer())

    # Evaluate the tensor in the default session; the result is a numpy array.
    a = a.eval()
    # str.format keeps this line valid and identical in output under both
    # Python 2 and Python 3.
    print("{} {}".format(a.shape, type(a)))
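需要注意的几点：一定要把代码写在 `with tf.Session() as sess:` 块内（这样 `sess` 会成为默认会话），并且要调用 `sess.run(tf.global_variables_initializer())` 初始化变量；初始化语句必须放在 `NetVLAD.forward(...)` 之后，因为 `tf.global_variables_initializer()` 只会初始化调用时图中已经存在的变量。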
`a.eval()` 会在当前默认会话中计算张量 `a`，并把结果作为 NumPy 数组（`numpy.ndarray`）返回，即把 tensor 转为 array。