炼丹时经常需要可视化中间层的结果进行分析,本文采用tensorflow2.1版本,进行代码展示
def show_middle(test_dataset_list, model, checkpoint_dir):
tf.keras.backend.clear_session()
#读取测试集
test_dataset = load_dataset_test()
#读取保存的模型参数
latest = tf.train.latest_checkpoint(checkpoint_dir)
#导入保存的模型结构
model = Model(model_name)
#加载参数
model.load_weights(latest)
outputs = []
##################
visualization_model = models.Model(inputs=model.input,outputs=model.get_layer('tf_op_layer_tmp').output)
#############
for sample, mask in test_dataset:
output = visualization_model.predict(sample)
outputs.append(output)
outputs = np.mean(outputs, axis = 0)
print(outputs)