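# 来源：《PyTorch深度学习实战》，2.1，一个识别图像主体的预训练网络
# Source: "Deep Learning with PyTorch" (Chinese edition), section 2.1: a pretrained network that recognizes the main subject of an image.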
from torchvision import models
from torchvision import transforms
from PIL import Image
import torch
# List the names exposed by torchvision.models (architectures and builders).
# The original bare `dir(models)` only displays in a REPL/notebook; in a
# script the result was silently discarded, so print it explicitly.
print(dir(models))
# An AlexNet built from its class; it is not used anywhere below and only
# demonstrates constructing an architecture object.
alexnet = models.AlexNet()
# ResNet-101 requested with pretrained weights.
# NOTE(review): `pretrained=True` is deprecated in newer torchvision releases
# in favour of `weights=models.ResNet101_Weights.DEFAULT`; it is kept here so
# the script still runs on older torchvision versions. Confirm the installed
# version before switching.
resnet = models.resnet101(pretrained=True)
# Show the network's module structure (bare `resnet` printed nothing in a script).
print(resnet)
# Per-channel normalization statistics used by the pipeline below
# (the values commonly used for ImageNet-pretrained models).
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Input pipeline applied to the PIL image before it is fed to the network:
# resize to 256, center-crop to 224x224, convert to a tensor, and normalize
# each channel with the mean/std above.
preprocess = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ]
)
# Load the test image. `.convert("RGB")` forces exactly three channels: a PNG
# may be RGBA or single-channel, and the 3-value Normalize step in
# `preprocess` cannot broadcast against a 4- or 1-channel tensor.
img = Image.open("D:\\PyCharm\\pythonProject\\bobby.PNG").convert("RGB")
img.show()
img_t = preprocess(img)
# Prepend a batch dimension of size 1; the network takes a batch of images.
batch_t = torch.unsqueeze(img_t, 0)
# Put the model in evaluation mode before running inference.
resnet.eval()
# A single forward pass for prediction needs no gradients; disabling autograd
# avoids building the backward graph.
with torch.no_grad():
    out = resnet(batch_t)
# The bare `out` expression only displays in a REPL; print it in a script.
print(out)
# Read the class names, one per line.
# Iterate the file object itself: the original `f.readline()` returned only
# the FIRST line, so the comprehension walked its characters and produced a
# list of single letters instead of class names.
# NOTE(review): utf-8 assumes a plain ASCII/UTF-8 class list; adjust if a
# localized file in another encoding is used.
with open("D:\\PyCharm\\pythonProject\\imagenet_classes.txt", encoding="utf-8") as f:
    labels = [line.strip() for line in f]
# Index of the highest-scoring class for each image in the batch.
_, index = torch.max(out, 1)
# Convert the first image's raw scores into percentages via softmax.
percentage = torch.nn.functional.softmax(out, dim=1)[0] * 100
# Print the predicted label and its confidence (a bare tuple shows nothing
# outside a REPL).
print(labels[index[0]], percentage[index[0]].item())