Network Visualization
一学期网课就这么混完了…CS231n的作业说停就停了3个月,果然人是一种懒惰性很强的动物。细数一下从开始写Assignment1的Q1开始,到今天也快一年了,我真是拖延症晚期患者…好吧,及时弥补,写完这篇关于Network Visualization的博客,暑期燥起来!
Saliency Maps
这次的作业可以分为三个部分,第一个部分是关于显著图的绘制。显著图可以告诉我们究竟是图像中的哪些像素影响了网络的分类结果,即到底是什么部分让网络知道这是帅帅的我,而不是一只傻狗!
为了计算显著图,我们首先需要计算出网络对于当前图片在正确类别上的得分,并且计算出损失在图像每个像素点上的梯度,然后对于每个像素点,取个通道中最大的梯度绝对值作为该像素点处的值。
这一部分需要补充的代码比较简单,正好熟悉一下Pytorch的用法:
def compute_saliency_maps(X, y, model):
"""
Compute a class saliency map using the model for images X and labels y.
Input:
- X: Input images; Tensor of shape (N, 3, H, W)
- y: Labels for X; LongTensor of shape (N,)
- model: A pretrained CNN that will be used to compute the saliency map.
Returns:
- saliency: A Tensor of shape (N, H, W) giving the saliency maps for the input
images.
"""
# Make sure the model is in "test" mode
model.eval()
# Make input tensor require gradient
X.requires_grad_()
saliency = None
##############################################################################
# TODO: Implement this function. Perform a forward and backward pass through #
# the model to compute the gradient of the correct class score with respect #
# to each input image. You first want to compute the loss over the correct #
# scores (we'll combine losses across a batch by summing), and then compute #
# the gradients with a backward pass. #
##############################################################################
# *****START OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
scores = model(X)
loss = scores.gather(1, y.view(-1, 1)).squeeze().sum()
loss.backward()
saliency = torch.max(torch.abs(X.grad), 1)[0]
# *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
##############################################################################
# END OF YOUR CODE #
##############################################################################
return saliency
值得一提的是这里的gather函数,这是Pytorch提供的一个函数,用来实现的功能和我们之前assignment中使用的 s[np.arange(N), y] 一样,用来从指定的数组中提取以y为下标的数值。
之后我们就看到了我们显著图的效果,当当当当!
然后我们看一个小问题
A friend of yours suggests that in order to find