Class Activation Mapping
论文:CVPR2016《Learning Deep Features for Discriminative Localization》
代码:https://github.com/acheketa/pytorch-CAM/blob/master/update.py
1、首先定义并训练好CNN网络。很重要的一点是:网络的最后一个卷积层必须只有一个通道,并且后面紧跟全连接层(即最后一层),因为下面的代码是把全连接层的权重与这张单通道特征图逐元素相乘来得到热力图的。可以参考GitHub上的inception-v3网络结构;下面是我自己的网络结构的最后两层。
这是我的网络最后部分的设计:其中conv3的输出就是要观察的热力图,fcl1是最后一个全连接层,把特征映射到类别数。
假设网络已经训练好,得到一个best_net。
2、CAM代码
# generate class activation mapping for the top1 prediction
def returnCAM(feature_conv, weight_softmax, class_idx):
    """Build one upsampled class activation map per requested class.

    Args:
        feature_conv: 4-D array (batch, channels, height, width) captured from
            the hooked conv layer. Only the first sample of the batch is used.
        weight_softmax: Weights of the final fully connected layer, indexed by
            class along the first axis.
        class_idx: Iterable of class indices to generate maps for.

    Returns:
        A list holding one 256x256 ``uint8`` map per entry of ``class_idx``,
        in the same order.
    """
    # generate the class activation maps upsample to 256x256
    size_upsample = (256, 256)
    _, nc, h, w = feature_conv.shape
    # Take the first sample explicitly so a batch larger than one does not
    # break the reshape.
    features = feature_conv[0].reshape((nc, h * w))
    output_cam = []
    for idx in class_idx:
        # Index with the current class (not the whole list) so that several
        # classes can be requested in one call.
        # NOTE: standard CAM uses weight_softmax[idx].dot(features); this
        # variant weights the feature map element-wise, so the reshape to
        # (h, w) below only succeeds when the product has h * w elements
        # (i.e. a single-channel feature map).
        cam = weight_softmax[idx] * features
        cam = cam.reshape(h, w)
        cam = cam - np.min(cam)
        peak = np.max(cam)
        # A constant map has peak 0 after the shift; avoid 0/0 -> NaN.
        cam_img = cam / peak if peak > 0 else np.zeros_like(cam)
        cam_img = np.uint8(255 * cam_img)
        output_cam.append(cv2.resize(cam_img, size_upsample))
    return output_cam
# hook the feature extractor
# Collects one NumPy array per forward pass of the hooked layer.
features_blobs = []


def hook_feature(module, input, output):
    """Forward hook that records the hooked layer's output.

    Args:
        module: The layer the hook is attached to (unused).
        input: The inputs passed to the layer (unused).
        output: The tensor produced by the layer; a CPU NumPy copy of it is
            appended to the module-level ``features_blobs`` list.
    """
    # detach() is the supported replacement for the legacy ``.data`` attribute.
    features_blobs.append(output.detach().cpu().numpy())
# Name of the last conv layer; its output feeds the last fully connected layer.
final_conv = 'conv3'
# getattr raises a clear AttributeError when the layer name is wrong, whereas
# the private ``_modules.get`` lookup silently returned None. The handle is
# kept so the hook can be detached later with ``hook_handle.remove()``.
hook_handle = getattr(best_net, final_conv).register_forward_hook(hook_feature)
# get weights parameters
params = list(best_net.parameters())
# params[-2] is used as the weight of the last layer, like
# [classes, hidden nodes].
# NOTE(review): this assumes the last layer also has a bias stored in
# params[-1] - confirm against the network definition.
weight_softmax = np.squeeze(params[-2].detach().cpu().numpy())
# define class type
classes = {0: 'Pos', 1: 'Neg'}
# read image
root = '/test.png'
raw = cv2.imread(root)
if raw is None:
    # cv2.imread reports an unreadable/missing file by returning None.
    raise FileNotFoundError('cannot read image: %s' % root)
# NOTE(review): the pixels are fed to the network as loaded by cv2.imread,
# with no normalization - confirm this matches the training preprocessing.
img = []
img.append(cv2.resize(raw.astype(np.float32), (256, 256)))  # (256, 256) is the model input size
# Put the input on whatever device the model lives on instead of assuming CUDA.
device = next(best_net.parameters()).device
data = torch.from_numpy(np.array(img)).type(torch.FloatTensor).to(device)
# Inference only: switch to eval mode and skip gradient tracking.
best_net.eval()
with torch.no_grad():
    logit = best_net(data.permute(0, 3, 1, 2))  # forward, NHWC -> NCHW
h_x = F.softmax(logit, dim=1).squeeze()  # softmax
probs, idx = h_x.sort(0, True)  # probabilities of classes, descending
# output: the prediction (every class, most probable first)
for prob, cls in zip(probs.tolist(), idx.tolist()):
    line = '{:.3f} -> {}'.format(prob, classes[cls])
    print(line)
# get the class activation maps
# features_blobs grows by one entry per forward pass, so take the most recent
# one; index 0 would be stale after the first run.
CAMs = returnCAM(features_blobs[-1], weight_softmax, [idx[0].item()])
# render the CAM and show
out_path = 'cam.jpg'
# The message now reports the file name that is actually written.
print('output %s for the top1 prediction: %s' % (out_path, classes[idx[0].item()]))
img = cv2.imread(root)
if img is None:
    # cv2.imread reports an unreadable/missing file by returning None.
    raise FileNotFoundError('cannot read image: %s' % root)
height, width, _ = img.shape
# Bring the map back to the original image resolution before colouring it.
CAM = cv2.resize(CAMs[0], (width, height))
heatmap = cv2.applyColorMap(CAM, cv2.COLORMAP_JET)
# Blend heat map and image.
result = heatmap * 0.3 + img * 0.5
# cv2.imwrite returns False instead of raising when the write fails.
if not cv2.imwrite(out_path, result):
    raise IOError('failed to write %s' % out_path)
这个代码是核心。