import os

import torch
import torchvision.utils
from PIL import Image
from torchvision import models, transforms

# Reference: https://blog.csdn.net/Thorn__/article/details/108728393

# Per-channel mean/std the torchvision ImageNet-pretrained weights expect.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


class Vgghook:
    """Capture the output of every layer in ``vgg.features`` via forward hooks.

    After a forward pass through ``vgg``, ``self.imgs[idx]`` holds the
    detached, CPU copy of the output of ``vgg.features[idx]``.
    """

    def __init__(self, vgg):
        self.imgs = {}
        # Keep the handles so the hooks can be unregistered with ``remove()``.
        self.handles = []
        # Hook every layer instead of a hard-coded 31 (VGG16's layer count),
        # so other VGG variants work too.
        for idx, layer in enumerate(vgg.features):
            self.handles.append(
                layer.register_forward_hook(self.create_hook_fn(idx))
            )

    def create_hook_fn(self, idx):
        """Return a forward hook that stores the module's output under ``idx``."""

        def hook_fn(module, inputs, output):  # module, module input, module output
            # detach() so the stored tensor does not keep the autograd graph alive.
            self.imgs[idx] = output.detach().cpu()

        return hook_fn

    def remove(self):
        """Unregister every hook installed by this instance."""
        for handle in self.handles:
            handle.remove()
        self.handles.clear()


def _load_vgg16():
    """Load ImageNet-pretrained VGG16, preferring the ``weights=`` API."""
    if hasattr(models, "VGG16_Weights"):  # torchvision >= 0.13
        return models.vgg16(weights=models.VGG16_Weights.DEFAULT)
    # Older torchvision only understands the (now deprecated) flag.
    return models.vgg16(pretrained=True)


def main(image_path="b.jpeg", name="小狗", out_dir="output"):
    """Classify ``image_path`` with VGG16 and save every ``features`` layer's maps.

    Prints the index of the highest-scoring class. Then, for each hooked layer,
    writes ``{out_dir}/{layer_idx}_{name}.png``. In that image each channel is
    one tile, resized to 50x60 and laid out 8 per row.
    """
    model = _load_vgg16()
    model.eval()
    vgg_hook = Vgghook(model)

    tfs = transforms.Compose([
        transforms.ToTensor(),  # [H,W,C] -> [C,H,W], scaled to [0, 1]
        # The pretrained weights expect ImageNet-normalized input; without this
        # the predicted class is unreliable.
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ])
    with Image.open(image_path) as raw:
        rgb = raw.convert("RGB")
    x = tfs(rgb).unsqueeze(0)  # [C,H,W] -> [1,C,H,W]

    with torch.no_grad():
        y = model(x)  # raw class scores for the current sample, one per class
    print(y.argmax(-1))

    # Create the output folder instead of requiring it to exist beforehand.
    os.makedirs(out_dir, exist_ok=True)
    resize = transforms.Resize(size=(50, 60))
    for idx, fmap in sorted(vgg_hook.imgs.items()):
        # [1,C,H,W] -> [C,1,H,W]: every channel becomes its own grayscale tile.
        # normalize/scale_each rescale each tile to [0, 1]; otherwise
        # activations outside that range are clamped and the maps saturate.
        torchvision.utils.save_image(
            resize(fmap.permute(1, 0, 2, 3)),
            f"{out_dir}/{idx}_{name}.png",
            nrow=8,
            normalize=True,
            scale_each=True,
        )


if __name__ == "__main__":
    main()
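# vgg16可视化 (VGG16 visualization)
# 最新推荐文章于 2024-08-23 22:22:20 发布 (scraped page footer: latest recommended article published 2024-08-23 22:22:20)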