1--常用预训练模型下载
如果不想手动下载预训练模型，可以使用以下代码让 torchvision 自动下载对应的权重文件（下载速度可能较慢）：
# Let torchvision download the ImageNet-pretrained VGG16 weights automatically.
# torchvision >= 0.13 deprecated the `pretrained=True` flag in favour of weight
# enums; VGG16_Weights.IMAGENET1K_V1 is the same checkpoint as
# vgg16-397923af.pth. Older torchvision has no such enum (AttributeError), so
# fall back to the legacy flag there.
try:
    vgg16 = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)
except AttributeError:
    vgg16 = models.vgg16(pretrained=True)
# The lines above replace loading a manually downloaded checkpoint:
# vgg16 = models.vgg16()
# weights = torch.load('./vgg16-397923af.pth')
# vgg16.load_state_dict(weights)
2--使用VGG16预训练模型
from torchvision import models
import cv2
from torchvision import transforms
import torch
from PIL import Image
import numpy as np
# --- Model initialisation ---------------------------------------------------
# Use the first CUDA GPU (the original hard-coded device index 0) when one is
# available, otherwise fall back to the CPU instead of crashing.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
vgg16 = models.vgg16()
# Load the checkpoint onto the CPU regardless of where it was saved; the
# parameters are moved to `device` together with the model below.
weights = torch.load('./vgg16-397923af.pth', map_location='cpu')
vgg16.load_state_dict(weights)
# eval() switches off the Dropout layers in torchvision's VGG classifier.
# Without it, every forward pass randomly drops activations and the
# inference outputs are not reproducible.
vgg16 = vgg16.to(device).eval()

# --- Preprocessing ------------------------------------------------------------
# Per-channel mean / std of the ImageNet dataset.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
# RGB PIL image -> normalised float tensor of shape [3, 224, 224].
# Note: Resize((224, 224)) stretches the image and does not preserve the
# aspect ratio.
tran = transforms.Compose([transforms.Resize((224, 224)),
                           transforms.ToTensor(),
                           normalize])
if __name__ == "__main__":
    # Read the image. cv2.imread returns a BGR uint8 array, or None when the
    # file is missing or cannot be decoded.
    img = cv2.imread('./test1.jpg')
    if img is None:
        raise FileNotFoundError("cannot read image './test1.jpg'")
    # Preprocess. OpenCV orders channels as BGR, while PIL and the ImageNet
    # normalisation in `tran` expect RGB, so swap the channels before
    # building the PIL image.
    img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)).convert('RGB')
    img = tran(img)          # [3, 224, 224]
    img = img.unsqueeze(0)   # [1, 3, 224, 224] -- add the batch dimension
    # Inference. No autograd graph is needed, which saves memory and time.
    with torch.no_grad():
        output = vgg16(img.to(device))  # [1, 1000]
    # Postprocess: take the single batch entry and move it to NumPy.
    output = output[0].cpu().numpy()    # [1000]
    print(output.shape)