def show_feature_map(feature_map):
feature_map = feature_map.squeeze(0)
feature_map = feature_map.cpu().numpy()
feature_map_num = feature_map.shape[0]
row_num = np.ceil(np.sqrt(feature_map_num))
plt.figure()
for index in range(1, feature_map_num + 1):
plt.subplot(row_num, row_num, index)
plt.imshow(feature_map[index - 1], cmap='viridis') # gray
plt.axis('off')
# imageio.imsave("mnist/three/"+str(index) + ".png", feature_map[index - 1])
plt.show()
PyTorch 特征图可视化函数
最新推荐文章于 2024-10-30 19:27:31 发布