1. 模型为 ResNet-152，默认输入图片大小是 224*224*3
2.获取除去全连接层的模型
import torch
import torch.nn as nn
import torchvision.models as models
from torch.autograd import Variable
# Load torchvision's ResNet-152 with pretrained weights.
# NOTE(review): `pretrained=True` is deprecated in newer torchvision releases
# in favour of the `weights=` argument; it is kept here for compatibility with
# the older API this file appears to target -- confirm the installed version.
resnet152 = models.resnet152(pretrained=True)

# Drop the last child module (the final fully connected classifier) and
# re-wrap the remaining layers, so the model returns features instead of
# class scores.
modules = list(resnet152.children())[:-1]
resnet152 = nn.Sequential(*modules)

# Freeze every remaining parameter so no gradients are computed for them.
# NOTE(review): the model is still in training mode at this point; if it is
# used purely as a feature extractor, calling `resnet152.eval()` before
# inference is probably wanted -- confirm against the code that uses it.
for p in resnet152.parameters():
    p.requires_grad = False
3. 使用新模型处理图片
import cv2
import torch
import torch.nn as nn
import torchvision.models as models
from torch.autograd import Variable
from PIL import Image
import torchvision.transforms as transforms
class ResNet152Bottom(nn.Module):
def __init__(self, original_model):
super(ResNet152Bottom, self).__init__()
self.features = nn.Sequential(*list(original_model.children())[:-1])
def forward(self, x):
x = self.features(