# Grad-CAM: visualize class-activation maps of different network models.
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import print_function
import cv2
import os
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from models import *
import torch.backends.cudnn as cudnn
def img_transform(img_in, transform):
    """
    Apply a torchvision transform pipeline to an image array and add a batch axis.

    :param img_in: np.array image, H*W*C
    :param transform: torchvision transform producing a C*H*W tensor
    :return: tensor of shape B*C*H*W (B == 1), ready for model input
    """
    pil_img = Image.fromarray(np.uint8(img_in.copy()))
    tensor = transform(pil_img)
    return tensor.unsqueeze(0)  # C*H*W --> 1*C*H*W
def img_preprocess(img_in, size=(256, 256)):
    """
    Resize a BGR image, convert to RGB, normalize, and return a model-ready tensor.

    :param img_in: ndarray, [H, W, C] in BGR channel order (as read by cv2)
    :param size: (width, height) target for cv2.resize; default keeps the
        original hard-coded 256x256 behavior
    :return: torch.Tensor of shape 1*C*H*W
    """
    img = img_in.copy()
    img = cv2.resize(img, size)
    img = img[:, :, ::-1]  # BGR --> RGB
    # NOTE(review): mean/std look dataset-specific (mineral images) — confirm
    # they match the statistics used at training time.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.28201964, 0.2816544, 0.21802017],
                             [0.25426927, 0.24146019, 0.1867752])
    ])
    return img_transform(img, transform)
def backward_hook(module, grad_in, grad_out):
    """Backward hook: stash the gradient w.r.t. the module's output.

    Appends a detached copy of ``grad_out[0]`` to the global ``grad_block``
    list so the Grad-CAM weights can be computed after ``backward()``.
    """
    grad = grad_out[0].detach()
    grad_block.append(grad)
def farward_hook(module, input, output):
    """Forward hook: record the module's output feature map.

    Appends ``output`` (still attached to the autograd graph) to the global
    ``fmap_block`` list. The 'farward' spelling is kept because the function
    is registered by this name elsewhere in the file.
    """
    fmap = output
    fmap_block.append(fmap)
def show_cam_on_image(img, mask, ClassName, out_dir, count):
    """
    Overlay a CAM heatmap on an image and save both the overlay and the raw image.

    :param img: np.float32 image in [0, 1], shape (H, W, 3), BGR order
    :param mask: np.array CAM in [0, 1], shape (H, W)
    :param ClassName: predicted class name, used in the output filenames
    :param out_dir: directory to write into (created if missing)
    :param count: running index appended to the filenames
    """
    heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
    heatmap = np.float32(heatmap) / 255
    cam = heatmap + np.float32(img)
    cam = cam / np.max(cam)  # rescale so the brightest pixel is 1.0
    CamName = ClassName + str(count) + "lapcam.jpg"
    RawName = ClassName + str(count) + "raw.jpg"
    path_cam_img = os.path.join(out_dir, CamName)
    path_raw_img = os.path.join(out_dir, RawName)
    # Bug fix: exists-then-makedirs is race-prone; exist_ok handles both cases.
    os.makedirs(out_dir, exist_ok=True)
    cv2.imwrite(path_cam_img, np.uint8(255 * cam))
    cv2.imwrite(path_raw_img, np.uint8(255 * img))
def comp_class_vec(ouput_vec, index=None):
    """
    Reduce a network output to the scalar score of one class (for backprop).

    :param ouput_vec: tensor of shape (1, num_classes) — the model output
    :param index: int, class to select; defaults to the argmax class
    :return: 0-dim tensor, the selected class score (differentiable)
    """
    num_classes = ouput_vec.size(-1)
    # Bug fix: `if not index:` treated class 0 like "no index given".
    if index is None:
        index = int(np.argmax(ouput_vec.cpu().data.numpy()))
    # scatter_ requires an int64 index of shape (1, 1)
    index = torch.tensor([[index]], dtype=torch.long)
    # Bug fix: the one-hot width was hard-coded to 7 classes.
    one_hot = torch.zeros(1, num_classes).scatter_(1, index, 1)
    one_hot.requires_grad = True
    # Bug fix: the original multiplied by the *global* `output` instead of the
    # `ouput_vec` argument, silently coupling this helper to main().
    class_vec = torch.sum(one_hot * ouput_vec.cpu())
    return class_vec
def gen_cam(feature_map, grads):
    """
    Build a Grad-CAM heatmap from a feature map and its gradients.

    :param feature_map: np.array, [C, H, W] activations of the hooked layer
    :param grads: np.array, [C, H, W] gradients w.r.t. that layer's output
    :return: np.array, [256, 256], values normalized to [0, 1]
    """
    cam = np.zeros(feature_map.shape[1:], dtype=np.float32)  # cam shape (H, W)
    # Channel weights: global-average-pool the gradients (Grad-CAM alpha_k).
    weights = np.mean(grads, axis=(1, 2))
    for i, w in enumerate(weights):
        cam += w * feature_map[i, :, :]
    cam = np.maximum(cam, 0)  # ReLU: keep positively-contributing regions only
    cam = cv2.resize(cam, (256, 256))
    cam -= np.min(cam)
    # Bug fix: an all-zero CAM (no positive contribution) previously divided
    # by zero and produced NaNs.
    peak = np.max(cam)
    if peak > 0:
        cam /= peak
    return cam
####----------------------------main----------------------------------------####
if __name__ == '__main__':
    # Use the GPU when available; DataParallel wrapping below depends on this.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    ####----------------------------setting the paths----------------------------####
    BASE_DIR = os.path.dirname(os.path.abspath(__file__))
    print(BASE_DIR)
    path_img = os.path.join(BASE_DIR, "test_img")
    output_dir = os.path.join(BASE_DIR, "Result", "backward_hook_cam")
    ####---------------------------import classes-------------------------------####
    # Class names; order must match the label order used at training time.
    classes = ('antimonite','arsenopyrite','blende','chalcopyrite','galena','native_gold','pyrite')
    ####--------------------------image reading; network loading----------------####
    # Global buffers filled by farward_hook / backward_hook during each pass.
    fmap_block = list()
    grad_block = list()
    ####------------------------------Model-------------------------------------####
    # net = vgg16(pretrained=True,NumClass=len(classes))
    # net = resnext50_32x4d(pretrained = True, NumClass=len(classes))
    net = create_RepVGG_B1g2(deploy=False,pretrained = True, clas=len(classes))
    # Model class name drives both checkpoint selection and hook placement below.
    net_name = net.__class__.__name__
    net = net.to(device)
    if device == 'cuda':
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = True
    net.eval()
    ####---------------------------Importing weight-----------------------------####
    print('==> Resuming from checkpoint..')
    assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
    # NOTE(review): backslash separators are Windows-only; os.path.join would
    # make these checkpoint paths portable.
    if net_name == 'ResNet':
        path = 'checkpoint\\resnext_acc99.88.pth'
    elif net_name == 'RepVGG':
        path = 'checkpoint\\repvgg_99.82.pth'
    elif net_name == 'VGG':
        path = 'checkpoint\\vgg_acc99.70.pth'
    else:
        # NOTE(review): `path` stays undefined on this branch, so torch.load
        # below raises NameError for an unrecognized model class.
        print('error')
    checkpoint = torch.load((path), map_location = torch.device('cpu'))
    '''
    /home/deyiwang/191112/pytorch-cifar-master/checkpoint/NWPU45
    state = {
    'net': net.state_dict(),
    'acc': acc,
    'epoch': epoch,
    }
    '''
    net.load_state_dict(checkpoint['net'])
    for root, dirs, files in os.walk(path_img, topdown = True):
        count = 0
        for name in files:
            # Register hooks on the last conv stage so each forward captures
            # its feature map and each backward captures its gradient.
            # NOTE(review): hooks are re-registered on every iteration and
            # therefore accumulate; registering once before the loop would
            # avoid duplicate hook invocations (harmless here only because
            # the code always reads element [0] of the buffers).
            if net_name == 'ResNet':
                net.module.layer4.register_forward_hook(farward_hook)
                net.module.layer4.register_backward_hook(backward_hook)
            elif net_name == 'RepVGG':
                net.module.stage4.register_forward_hook(farward_hook)
                net.module.stage4.register_backward_hook(backward_hook)
            elif net_name == 'VGG':
                net.module.features[40].register_forward_hook(farward_hook)
                net.module.features[40].register_backward_hook(backward_hook)
            else:
                print('error')
            print('======>>>>>>hook done<<<<<<======')
            im_dir = os.path.join(path_img, name)
            ####--------------------------image reading; network loading----------------####
            # Reset the hook buffers so index [0] below is this image's capture.
            fmap_block = list()
            grad_block = list()
            img = cv2.imread(str(im_dir), 1)  # H*W*C, BGR
            img_input = img_preprocess(img)
            # forward
            output = net(img_input)
            idx = np.argmax(output.cpu().data.numpy())
            print("predict: {}".format(classes[idx]))
            # backward: backprop from the top-class score; backward_hook fills
            # grad_block during this call.
            net.zero_grad()
            class_loss = comp_class_vec(output)
            class_loss.backward()
            print('======>>>>>>backward done<<<<<<======')
            # Build the CAM from the first captured gradient/feature-map pair.
            grads_val = grad_block[0].cpu().data.numpy().squeeze()
            fmap = fmap_block[0].cpu().data.numpy().squeeze()
            cam = gen_cam(fmap, grads_val)
            # Save the CAM overlay and the raw (resized) image.
            ClassName = classes[idx]
            img_show = np.float32(cv2.resize(img, (256, 256))) / 255
            show_cam_on_image(img_show, cam, ClassName, output_dir,count)
            count += 1