使用 Python 提取一张图片在 ResNet-18 网络中某一层的特征。
需提前下载好预训练权重文件 resnet18-f37072fd.pth,并放在脚本所在目录。
import torch
from torchvision import models, transforms
from torch.autograd import Variable
import torch.nn as nn
from PIL import Image
# Names of the ResNet-18 child modules whose outputs should be captured.
extract_list = ["conv1", "maxpool", "layer1", "layer2", "layer3", "layer4", "avgpool", "fc"]
img_path = "./A.jpg"
saved_path = "./1.txt"

# Build the ResNet-18 architecture without downloading anything, then load the
# locally downloaded checkpoint. Calling models.resnet18() with no arguments
# yields an untrained model on every torchvision version and avoids the
# deprecated `pretrained=` keyword. map_location="cpu" lets the checkpoint load
# on machines without a GPU.
resnet = models.resnet18()
resnet.load_state_dict(torch.load('./resnet18-f37072fd.pth', map_location="cpu"))  # load the weights
# Inference mode: BatchNorm must use its stored running statistics instead of
# the statistics of this single image (and must not update them).
resnet.eval()
# print(resnet)  # uncomment to inspect the model structure

# Preprocessing: resize, center-crop to 224x224, convert to a [0, 1] tensor,
# then normalize with the ImageNet per-channel mean/std that the torchvision
# ResNet-18 weights were trained with.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# Force 3-channel RGB so grayscale / RGBA / palette images match conv1's input.
# The context manager closes the underlying file handle.
with Image.open(img_path) as pil_img:
    img = transform(pil_img.convert("RGB"))

# Add a batch dimension: (3, 224, 224) -> (1, 3, 224, 224). A plain tensor
# replaces the deprecated autograd.Variable wrapper; requires_grad is already False.
x = torch.unsqueeze(img, dim=0).float()
# Optional GPU execution (define `use_gpu` before enabling):
# if use_gpu:
#     x = x.cuda()
#     resnet = resnet.cuda()
# 中间层特征提取
class Feat