"""AlexNet single-image inference demo.

Loads pretrained AlexNet weights from ../data, classifies one image, prints
the top-1 label (English and Chinese) plus the inference time, and plots the
image annotated with the top-5 predictions.
"""
import os

# Allow more than one OpenMP runtime in the process (a common clash between
# torch and MKL-backed numpy). Set it BEFORE the heavy imports below, because
# those imports may already load the runtime.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
# NOTE(review): NLS_LANG is an Oracle client setting with no visible effect in
# this script; kept only to preserve the original environment.
os.environ['NLS_LANG'] = 'SIMPLIFIED CHINESE_CHINA.UTF8'

import json
import time

import torch
import torchvision.models as models
import torchvision.transforms as transforms
from matplotlib import pyplot as plt
from PIL import Image

# Directory containing this script; data paths in __main__ are built relative to it.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def img_transform(img_rgb, transform=None):
    """Apply ``transform`` to an image and return the result.

    :param img_rgb: PIL Image to convert
    :param transform: callable such as a torchvision ``Compose`` pipeline
    :return: whatever ``transform`` returns (a tensor for the pipeline in this file)
    :raises ValueError: if ``transform`` is None
    """
    if transform is not None:
        return transform(img_rgb)
    raise ValueError("找不到transform!必须有transform对img进行处理")
def load_class_names(p_clsnames, p_clsnames_cn):
    """Load the English and Chinese class-name tables.

    :param p_clsnames: path to a JSON file whose top-level value is indexed
        by class id (used as a list in ``__main__``)
    :param p_clsnames_cn: path to a UTF-8 text file with one Chinese class
        name per line
    :return: tuple ``(class_names, class_names_cn)``; entries of
        ``class_names_cn`` keep their trailing newline, as returned by
        ``readlines()``
    """
    # Explicit UTF-8 for both files: without it the JSON would be decoded
    # with the platform locale encoding (e.g. GBK/cp1252 on Windows).
    with open(p_clsnames, "r", encoding="utf-8") as f:
        class_names = json.load(f)
    with open(p_clsnames_cn, "r", encoding="utf-8") as f:
        class_names_cn = f.readlines()
    return class_names, class_names_cn
def get_model(path_state_dict, vis_model=False):
    """Build AlexNet and load pretrained weights from a state-dict file.

    :param path_state_dict: path to a ``.pth`` file holding an AlexNet state dict
    :param vis_model: if True, print a layer summary with ``torchsummary``
        (must be installed) before the model is moved to ``device``
    :return: the model in eval mode, placed on the module-level ``device``
    """
    model = models.alexnet()
    # map_location="cpu" lets weights saved from a GPU session load on a
    # CPU-only machine; the model is moved to ``device`` below anyway.
    pretrained_state_dict = torch.load(path_state_dict, map_location="cpu")
    model.load_state_dict(pretrained_state_dict)
    # eval() switches layers such as Dropout/BatchNorm to inference behaviour,
    # which differs from their training-mode behaviour.
    model.eval()

    if vis_model:
        from torchsummary import summary
        # The model is still on the CPU at this point, so summarize on CPU.
        summary(model, input_size=(3, 224, 224), device="cpu")

    model.to(device)
    return model
def process_img(path_img):
    """Load an image file and turn it into a normalized AlexNet input batch.

    :param path_img: path to an image file readable by PIL
    :return: tuple ``(img_tensor, img_rgb)``: ``img_tensor`` is a
        1x3x224x224 float tensor on ``device``; ``img_rgb`` is the RGB
        PIL image, kept for display
    """
    # Standard ImageNet per-channel (R, G, B) mean/std used by torchvision's
    # pretrained models; the weights expect inputs normalized this way.
    norm_mean = [0.485, 0.456, 0.406]
    norm_std = [0.229, 0.224, 0.225]
    inference_transform = transforms.Compose([
        transforms.Resize(256),             # shorter side -> 256 px
        transforms.CenterCrop((224, 224)),  # central 224x224 patch
        transforms.ToTensor(),              # HWC uint8 [0, 255] -> CHW float [0, 1]
        transforms.Normalize(norm_mean, norm_std),
    ])

    # path --> img. convert('RGB') guarantees 3 channels (drops alpha,
    # expands grayscale) and returns an independent image, so the file
    # handle can be closed by the context manager right away.
    with Image.open(path_img) as img:
        img_rgb = img.convert('RGB')

    # img --> tensor (C, H, W), then add a leading batch dim -> (1, C, H, W)
    img_tensor = img_transform(img_rgb, inference_transform)
    img_tensor.unsqueeze_(0)
    img_tensor = img_tensor.to(device)
    return img_tensor, img_rgb
if __name__ == "__main__":
    # ---- config: all data files live in ../data relative to this script ----
    path_state_dict = os.path.join(BASE_DIR, "..", "data", "alexnet-owt-4df8aa71.pth")
    # path_img = os.path.join(BASE_DIR, "..", "data", "Golden Retriever from baidu.jpg")
    path_img = os.path.join(BASE_DIR, "..", "data", "tiger cat.jpg")
    path_classnames = os.path.join(BASE_DIR, "..", "data", "imagenet1000.json")
    path_classnames_cn = os.path.join(BASE_DIR, "..", "data", "imagenet_classnames.txt")

    # load class names (English table from JSON, Chinese names one per line)
    cls_n, cls_n_cn = load_class_names(path_classnames, path_classnames_cn)

    # 1/5 load img
    img_tensor, img_rgb = process_img(path_img)

    # 2/5 load model (True -> also print a torchsummary layer summary)
    alexnet_model = get_model(path_state_dict, True)

    # 3/5 inference tensor --> vector
    with torch.no_grad():  # inference only: no autograd graph needed
        time_tic = time.time()
        outputs = alexnet_model(img_tensor)
        # CUDA kernels run asynchronously; wait for them so the timing
        # covers the forward pass itself, not just the kernel launch.
        if device.type == "cuda":
            torch.cuda.synchronize()
        time_toc = time.time()

    # 4/5 index to class names
    _, pred_int = torch.max(outputs, 1)          # top-1 class index per sample
    _, top5_idx = torch.topk(outputs, 5, dim=1)  # top-5 class indices per sample
    # .item() replaces int(ndarray), which is deprecated for ndim > 0 arrays
    pred_idx = pred_int.item()
    pred_str, pred_cn = cls_n[pred_idx], cls_n_cn[pred_idx]
    # Chinese names come from readlines() and end with "\n"; strip it so no
    # extra blank line is printed.
    print("img: {} is: {}\n{}".format(os.path.basename(path_img), pred_str, pred_cn.strip()))
    print("time consuming:{:.2f}s".format(time_toc - time_tic))

    # 5/5 visualization: show the image with the top-5 labels stacked at top-left
    plt.imshow(img_rgb)
    plt.title("predict:{}".format(pred_str))
    top5_list = top5_idx[0].tolist()  # the batch holds a single image
    for rank, cls_idx in enumerate(top5_list, start=1):
        plt.text(5, 15 + (rank - 1) * 30, "top {}:{}".format(rank, cls_n[cls_idx]), bbox=dict(fc='yellow'))
    plt.show()
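# alexnet进行判别 (AlexNet image classification)
# 最新推荐文章于 2022-05-07 10:34:03 发布 (web-page residue: "latest recommended article published 2022-05-07 10:34:03")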