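# 这是无意中看到的这位博主写的代码：在模型推理时，把图片经过的各卷积层的特征图进行可视化，蛮好的。以下代码都是从博主那里拷贝过来的，为了方便自己看就拷贝贴在这里，防止博主删除；如果博主觉得侵权，还请告知，我会删除。
# (English) Found by chance on a blogger's post: while the model runs inference, the feature maps produced by the convolution layers the image passes through are visualised. All code below is copied from that blogger for my own reference, in case the original is deleted; if the blogger considers this an infringement, please let me know and I will remove it.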
import cv2
import torch
import numpy as np
import torchvision
import torch.nn as nn
from PIL import Image
import torch.nn.functional as F
from torchvision import transforms
from models.experimental import attempt_load
# Checkpoint whose convolution activations are visualised.
weights = r'D:\github\YoloV5\yolov5s.pt'
# Load the checkpoint onto the CPU as an FP32 model.
model = attempt_load(weights, map_location='cpu')

# Feature maps captured by forward_hook, keyed 'layer_<k>' -> list of tensors.
fmap_dict = {}
# Number of forward_hook invocations so far; used to build the next key.
count = 0
def forward_hook(module, input, output):
    """Forward hook that records a detached copy of ``output`` in ``fmap_dict``.

    Every call is filed under the key ``'layer_<count>'`` and then increments
    the module-level ``count``. Keys therefore follow the order in which the
    hooked modules actually execute.

    Args:
        module: The module the hook is attached to (unused).
        input: The module's positional inputs (unused).
        output: The module's output tensor.
    """
    global count
    key_name = 'layer_' + str(count)
    count += 1
    # setdefault: a second forward pass produces keys beyond the ones that
    # were pre-created at registration time; create them instead of raising
    # KeyError.
    # detach: keep only the values, not the autograd graph, so the map can
    # later be converted with .numpy() and the graph can be freed.
    fmap_dict.setdefault(key_name, []).append(output.detach())
# Layer types whose outputs are captured. This must be a tuple (note the
# trailing comma). Add e.g. torch.nn.ReLU, torch.nn.MaxPool2d or
# torch.nn.AdaptiveAvgPool2d here to visualise those layers as well.
modules_for_plot = (torch.nn.Conv2d,)
n = 0
# named_modules() yields every sub-module of the network together with its name.
for name, sub_module in model.named_modules():
    if isinstance(sub_module, modules_for_plot):
        # Pre-create one empty slot per hooked module: layer_0 .. layer_{n-1}.
        # NOTE: forward_hook labels outputs with its own counter, in execution
        # order. That order need not match the named_modules() order used here,
        # so 'layer_k' is the k-th hooked module to run, not necessarily the
        # k-th one registered.
        key_name = 'layer_' + str(n)
        n += 1
        fmap_dict.setdefault(key_name, list())
        sub_module.register_forward_hook(forward_hook)
def readimg(path_img=r"D:\data\coco128\images\train2017\000000000165.jpg", imgz=640):
    """Load one image, display it, and return it as a model-ready tensor.

    The image is shown in an OpenCV window named 'src' (resized to
    imgz x imgz). It is also loaded with PIL as RGB, resized to (imgz, imgz)
    and converted to a float tensor.

    Args:
        path_img: Path of the image to feed through the model.
        imgz: Side length in pixels that the image is resized to.

    Returns:
        A float tensor of shape (1, 3, imgz, imgz) with values in [0, 1].

    Raises:
        FileNotFoundError: If OpenCV cannot read ``path_img``.
    """
    img = cv2.imread(path_img)
    if img is None:
        # cv2.imread signals a missing or unreadable file by returning None.
        raise FileNotFoundError(f'cannot read image: {path_img}')
    img = cv2.resize(img, (imgz, imgz))
    cv2.imshow('src', img)
    cv2.waitKey(1)

    img_transforms = transforms.Compose([
        transforms.Resize((imgz, imgz)),
        transforms.ToTensor()])  # PIL uint8 [0, 255] -> CHW float [0, 1]
    # Context manager closes the underlying file; convert() returns a loaded copy.
    with Image.open(path_img) as src:
        img_pil = src.convert('RGB')
    # YOLOv5 expects inputs normalised to [0, 1] (its detect.py divides uint8
    # images by 255). ToTensor already yields that range, so the tensor is not
    # rescaled to [0, 255].
    img_tensor = img_transforms(img_pil)
    return img_tensor.unsqueeze(0)  # chw --> bchw
if __name__ == '__main__':
    img_tensor = readimg()
    # Inference only: no_grad stops autograd from recording the graph. The
    # hooked feature maps are then plain tensors that .numpy() accepts.
    with torch.no_grad():
        output = model(img_tensor)

    for layer_name, fmap_list in fmap_dict.items():
        if not fmap_list:
            # This hooked module never ran during the forward pass.
            print(layer_name, 'no feature map captured')
            continue
        fmap = fmap_list[0].detach().cpu()
        print(layer_name, fmap.shape)
        # make_grid treats dim 0 as the image index. Swapping batch and channel
        # turns each channel into its own tile. Out-of-place, so the tensor
        # stored in fmap_dict is left untouched.
        if fmap.shape[0] < fmap.shape[1]:
            fmap = fmap.transpose(0, 1)
        print(fmap.shape[0])
        fmap = fmap.sigmoid()
        fmap = F.interpolate(fmap, size=[224, 224], mode="bilinear", align_corners=False)
        # normalize/scale_each rescale every tile into [0, 1] before padding is
        # added, so pad_value must be in that same range (1.0 = white). 255
        # would overflow uint8 after the * 255 below.
        fmap_grid = torchvision.utils.make_grid(fmap, normalize=True, scale_each=True, nrow=6, pad_value=1.0)
        fmap_grid = fmap_grid.permute(1, 2, 0)
        grid_img = np.clip(fmap_grid.numpy() * 255, 0, 255).astype(np.uint8)
        # All layers reuse one window, shown one after another for 3 s each.
        cv2.imshow('vis0', grid_img)
        cv2.waitKey(3000)
    cv2.destroyAllWindows()