神经网络的三种可视化方法——用keras和MXNet(gluon)实现
目录
keras之父弗朗西斯科肖莱在他的书中提到了CNN的三种常用可视化方法, 同样的算法原理在李宏毅深度学习教程的ExplainableML单元也有提及, 本博客分别使用keras和MXNet(gluon)框架实现了这三种可视化算法, keras实现参考了肖莱书中的代码做了一定修改, gluon版为相同算法的不同框架实现, 后续会补上Pytorch框架的实现, gluon和Pytorch作为动态图框架在可视化上有先天的便利.
github的jupyternotebook地址: link https://github.com/TomMao23/CNN–Visualization
概述
-
第一种可视化方法是 可视化卷积神经网络的中间输出 , 即可视化特征图: 这种方法有助于理解卷积神经网络连续的层如何对输入进行变换,也有助于初步了解卷积神经网络浅层每个过滤器的含义, 缺点是对于较深层的特征图即使可视化出来也难以理解, 对增强深层模型的可解释性没有帮助.
-
第二种方法 可视化卷积神经网络的滤波器 : 固定网络参数, 随机初始化输入图像, 用单个滤波器(卷积核)的输出对输入图像的梯度, 做梯队上升法更新输入图像, 最大化单个滤波器的输出均值, 从而求得 此滤波器的最大响应图像. 这种方法利于人观察到神经网络从浅入深的学到的层级特征, 对深层滤波器可视化可以得到复杂轮廓和接近人概念中的图像. (需要注意的是可视化某个滤波器的最大响应图实际上指的是从输入到这个卷积核输出这条通路, 不仅仅是这个卷积核)
类似鸟的最大响应输入图 -
第三种方法可视化图像中 类激活的热力图: 对于输出的某个类别, 求该类别输出对最后一个卷基层输出特征图的梯度, 对每个通道梯度求均值得到与特征图通道数相同的权重向量描述每个通道的重要程度. 之后每个通道特征图乘以权重后按通道求平均得到"热力图", 可以用双线性插值resize到与图像相同大小. 热力图有助于让人理解神经网络主要根据哪些像素将图片中对象判断为"猫"等其他对象. 谷歌EfficientNet论文中便使用的了这种方法来解释为什么EfficientNet表现更好(热力图聚焦图中对象相关区域得到更多细节)
keras实现
注: 如果后续过程出现了cudnn和显存问题可尝试文件头加上以下代码解决tensorflow后端显存占用的bug
#解决tensorflow后端显存占用的bug
from tensorflow.compat.v1 import ConfigProto
from tensorflow.compat.v1 import InteractiveSession
config = ConfigProto()
config.gpu_options.allow_growth = True
session = InteractiveSession(config=config)
keras特征图可视化
首先加载模块和模型, 编写辅助函数把图像转为网络可接受的张量. 本文以VGG16为例做可视化, ResNet等模型的可视化方法与其大同小异
#加载我们的可视化对象VGG16预训练网络, 前两种方法只需要卷积层, 第三种方法热力图需要使用到全连接层, 所以在这里直接加载所有层
from keras.applications import VGG16
from keras import backend as K
import cv2
import numpy as np
import matplotlib.pyplot as plt
model = VGG16(weights='imagenet', include_top=True)
#编写辅助函数对输入网络的图像预处理: 原书使用image模块, 这里改成更常用的opencv
#BGR通道转RGB, 本应该做归一化或标准化但keras的预训练模型实际上接受的输入是归一化前的故注释, 加入批量维度, 注: keras(TensorFlow后端)使用通道在后格式即BHWC不用再改通道维度位置
img_path = "00000001_020.jpg"
img = cv2.imread(img_path)
def image_to_tensor(img):
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = cv2.resize(img, (224, 224))
#tensor = img / 255.
tensor = np.expand_dims(img, axis=0)
return tensor
tensor = image_to_tensor(img)
#对比一下形状
print(img.shape)
print(tensor.shape)
#可视化原图, 注意OpenCV默认BGR格式, matplotlib可视化需要转RGB
plt.imshow(img[:, :, ::-1])
输入图像:
利用keras的函数式API构建用于可视化的模型. 实际上是重新创建了一个单输入多输出网络, 结构参数输入同VGG16, 输出为需要可视化特征图的层. 本例我们可视化VGG五个Block的第一个卷基层, 对应模型层下标为[1,4,7,11,15]. activations为一个列表, 5个元素对应了5个层的输出特征图
#使用keras函数式API, 实际上是重新创建了一个单输入多输出网络, 结构参数输入同VGG16, 输出为这几个层
from keras import models
layer_outputs = []
for i, layer in enumerate(model.layers):
if i in [1,4,7,11,15]:
print(layer.name)
layer_outputs.append(layer.output)
activation_model = models.Model(inputs=model.input, outputs=layer_outputs)
#activations为一个列表, 5个元素对应了5个层的输出特征图
activations = activation_model.predict(tensor)
编写把张量转化为用于显示的图像的辅助函数, 尝试可视化block2的第一个卷基层第5个卷积核
#用matplotlib画出每层特征图
#先编写显示的辅助函数
#1.输出张量为(Batch, Hight, Width, Chanels)的格式, 需要去掉批量维度
#2.输出为float32类型, 要显示应该配合可视化工具一般转为uint8
#3.输出的范围较为随机, 为可视化方便采用减均值除以标准差标准化为均值为0的正态分布, 之后乘以64放大范围, 加128把均值移动到128, 对于转uint8溢出部分采用截断处理
def tensor_to_image(tensor):
tensor = tensor[0]
tensor -= tensor.mean()
tensor /= tensor.std()
tensor *= 64
tensor += 128
img = np.clip(tensor, 0, 255).astype('uint8')
return img
# 选其中一张特征图测试可视化
t1 = activations[1][:, :, :, 5] #block2的第一个卷基层第5个卷积核的输出
plt.imshow(tensor_to_image(t1), cmap='viridis')
分别可视化每个Block第一个卷层的特征图
#可视化第一个Block第一个卷基层所有特征图, 64通道64张图
# 把特征图通道移到第一维, 方便遍历
featuremaps = np.transpose(activations[0], [3, 0, 1, 2])
plt.figure(figsize=(40, 10))
for i, featuremap in enumerate(featuremaps):
plt.subplot(4, 16, i+1)
plt.imshow(tensor_to_image(featuremap))
#可视化第二个Block第一个卷基层所有特征图, 128通道128张图
featuremaps = np.transpose(activations[1], [3, 0, 1, 2])
plt.figure(figsize=(20, 10))
for i, featuremap in enumerate(featuremaps):
plt.subplot(8, 16, i+1)
plt.imshow(tensor_to_image(featuremap))
#可视化第三个Block第一个卷基层所有特征图, 256通道256张图
featuremaps = np.transpose(activations[2], [3, 0, 1, 2])
plt.figure(figsize=(20, 20))
for i, featuremap in enumerate(featuremaps):
plt.subplot(16, 16, i+1)
plt.imshow(tensor_to_image(featuremap))
#可视化第四个Block第一个卷基层所有特征图, 512通道512张图
featuremaps = np.transpose(activations[3], [3, 0, 1, 2])
plt.figure(figsize=(20, 40))
for i, featuremap in enumerate(featuremaps):
plt.subplot(32, 16, i