# Import the required packages (导入相关包)
import numpy as np
from keras.models import load_model
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.pyplot as plt
from keras.preprocessing import image
from keras import layers
from keras import models
import math
# Load the model (载入模型)
# Project root directory; the trained model is read from <work_dir>/model/.
work_dir = 'D:/py/SAR'


def read_model():
    """Load the trained Keras model stored at ``<work_dir>/model/model_s5_nofilt.h5``.

    Returns:
        Whatever ``keras.models.load_model`` returns for that file.
    """
    model_path = work_dir + '/model/model_s5_nofilt.h5'
    loaded = load_model(model_path)
    return loaded


# Load the model once at module level so the statements below can inspect it.
model = read_model()
# Print the name of every layer in the model (输出模型每层的名字)
# Collect the name of each layer of the loaded model, in order, and print the list.
layer_names = []
for each_layer in model.layers:
    layer_names.append(each_layer.name)
print(layer_names)
# Plot and save the intermediate feature maps (输出特征图)
def _enhance_for_display(channel_image):
    """Return a contrast-stretched uint8 copy of one 2-D feature map.

    The map is standardised (zero mean, unit std), scaled by 64, shifted by
    128 and clipped to [0, 255] so weak activations are easier to see by eye.
    The input array is not modified.
    """
    stretched = channel_image - channel_image.mean()
    std = stretched.std()
    if std > 0:
        # Constant channels (e.g. all-zero ReLU maps) have std == 0; skip the
        # division instead of filling the tile with NaNs.
        stretched = stretched / std
    stretched = stretched * 64 + 128
    return np.clip(stretched, 0, 255).astype('uint8')


def _tile_feature_maps(activation, num_channels=6, cols=6, enhance=False):
    """Lay out the channels of one layer's activation side by side in one 2-D image.

    Args:
        activation: 4-D array indexed as ``activation[0, row, col, channel]``;
            only the first sample of the batch is drawn.
        num_channels: how many leading channels to draw. ``None`` draws all
            of them; values larger than the channel count are clamped.
        cols: number of tiles per grid row.
        enhance: if True, contrast-stretch each tile with
            ``_enhance_for_display`` before placing it.

    Returns:
        Float array of shape ``(h * rows, w * cols)``, where ``h``/``w`` are
        ``activation.shape[1]``/``activation.shape[2]`` and
        ``rows = max(1, ceil(n_drawn / cols))``.
    """
    h, w, total_channels = activation.shape[1], activation.shape[2], activation.shape[3]
    if num_channels is None:
        num_channels = total_channels
    n_drawn = min(num_channels, total_channels)
    rows = max(1, math.ceil(n_drawn / cols))
    grid = np.zeros((h * rows, w * cols))
    for i in range(n_drawn):
        tile = activation[0, :, :, i]
        if enhance:
            tile = _enhance_for_display(tile)
        r, c = divmod(i, cols)  # zero-based, row-major placement
        # Rows advance by the map height h, columns by the map width w.
        grid[r * h:(r + 1) * h, c * w:(c + 1) * w] = tile
    return grid


def visible(model=None, img_path='D:/py/SAR/mstar/train/2S1/HB14931.jpg',
            num_layers=5, num_channels=6, cols=6, enhance=False,
            target_size=(100, 100)):
    """Plot and save the feature maps of the model's first layers for one image.

    For each of the first ``num_layers`` entries of ``model.layers``, the first
    ``num_channels`` channels of that layer's output are tiled into a grid of
    ``cols`` columns. The grid is saved to ``"<k>@.jpg"`` (``k`` = layer index)
    in the current working directory at 300 dpi and then shown.

    Args:
        model: Keras model to inspect; when None it is loaded via ``read_model()``.
        img_path: path of the image fed through the network.
        num_layers: how many leading layers to visualise.
        num_channels: channels drawn per layer (None = all; clamped to the
            number available).
        cols: tiles per grid row.
        enhance: contrast-stretch each tile for easier viewing.
        target_size: (height, width) the image is resized to before inference.
    """
    if model is None:
        model = read_model()
    layer_outputs = [layer.output for layer in model.layers[:num_layers]]
    activation_model = models.Model(inputs=model.input, outputs=layer_outputs)

    img = image.load_img(img_path, target_size=target_size)
    # Scale pixel values by 1/255 (assumes 8-bit input) and add a batch axis.
    img_tensor = np.expand_dims(image.img_to_array(img) / 255.0, axis=0)

    activations = activation_model.predict(img_tensor)
    if not isinstance(activations, list):
        # With a single requested output, predict may return a bare array.
        activations = [activations]

    for k, activation in enumerate(activations):
        print(activation.shape)
        if activation.ndim != 4:
            # Only (batch, h, w, channels) outputs can be tiled; skip the rest
            # (e.g. Flatten/Dense outputs) instead of crashing on shape[2].
            continue
        img_grid = _tile_feature_maps(activation, num_channels, cols, enhance)

        # Size the figure to the grid's own aspect ratio so the saved image
        # keeps full resolution after bbox_inches='tight' cropping.
        grid_h, grid_w = img_grid.shape
        fig_width = 20.0
        fig = plt.figure(figsize=(fig_width, fig_width * grid_h / grid_w))
        plt.imshow(img_grid, aspect='equal', cmap='viridis')
        plt.xticks([])
        plt.yticks([])
        plt.grid(False)
        plt.savefig(str(k) + "@.jpg", dpi=300, bbox_inches='tight')
        plt.show()
        # Release the figure so non-interactive backends don't accumulate them.
        plt.close(fig)