# 模型的预测 (model prediction): load the trained classifier and run inference on test images.
import os
import time
from PIL import Image
from major_dataset import LoadDataset
import major_config
import torch
import torchvision.transforms as transforms
from torchsummary import summary
# Human-readable label for each model output index (the ten CIFAR-10 class
# names); the script maps an argmax index to a name via classes[index].
classes = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
]

# Preprocessing applied to each image before it is fed to the network:
# resize to 32x32, convert to a tensor, then normalize with the mean/std
# configured in major_config.
inference_transform = transforms.Compose(
    [
        transforms.Resize(size=(32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=major_config.norm_mean,
            std=major_config.norm_std,
        ),
    ]
)
def preprocessing(img, transform=None):
    """Apply ``transform`` to ``img`` and return the result.

    Args:
        img: the input to preprocess (in this script, a PIL RGB image).
        transform: callable applied to ``img``. Despite the ``None`` default,
            it is required.

    Returns:
        Whatever ``transform(img)`` returns (a tensor when called with
        ``inference_transform``).

    Raises:
        Exception: if ``transform`` is None.
    """
    # Check the *argument*, not the imported ``transforms`` module (which is
    # never None). Checking the module let a missing transform fall through
    # to ``None(img)`` and fail with an unhelpful TypeError.
    if transform is None:
        raise Exception("无transform进行预处理")
    return transform(img)
def get_model(saved_model_path=major_config.path_saved_model, visual_model=False, input_size=(3, 32, 32)):
    """Load saved weights into the configured model and return it.

    Args:
        saved_model_path: path to a file produced by ``torch.save`` holding a
            state dict for ``major_config.model``.
        visual_model: if True, print a layer summary via ``torchsummary``.
        input_size: input shape passed to the summary; only used when
            ``visual_model`` is True.

    Returns:
        ``major_config.model`` with the loaded weights. NOTE(review): this is
        the object held by major_config, mutated in place rather than copied.
    """
    net = major_config.model
    # map_location="cpu" lets a checkpoint saved from a CUDA run be loaded on
    # a CPU-only machine. load_state_dict copies the tensors into net's own
    # parameters, so the caller still chooses the final device via net.to().
    state_dict = torch.load(saved_model_path, map_location="cpu")
    net.load_state_dict(state_dict)
    if visual_model:
        summary(net, input_size=input_size, device="cpu")
    return net
def _sync_if_cuda(device):
    """Block until queued CUDA work finishes when ``device`` is a CUDA device."""
    # CUDA kernels launch asynchronously. Without a sync, the measured time
    # covers only the kernel launch, not the forward pass itself.
    # NOTE(review): assumes major_config.device is a str or torch.device.
    if torch.device(device).type == "cuda":
        torch.cuda.synchronize()


def _predict_image(net, img_path, device):
    """Classify a single image file with ``net``.

    Returns:
        Tuple ``(outputs, pred_int, pred_str, elapsed)``:
        - ``outputs``: the raw network output.
        - ``pred_int``: the argmax tensor over dim 1.
        - ``pred_str``: the matching name from ``classes``.
        - ``elapsed``: forward-pass wall time in seconds.
    """
    # Use a context manager so the file handle is closed; convert() returns a
    # fully loaded copy.
    with Image.open(img_path) as img:
        img_rgb = img.convert('RGB')
    # Add a batch dimension of 1 before moving the tensor to the device.
    img_tensor = preprocessing(img_rgb, inference_transform).unsqueeze(0).to(device)

    _sync_if_cuda(device)
    time_start = time.perf_counter()
    outputs = net(img_tensor)
    _sync_if_cuda(device)
    elapsed = time.perf_counter() - time_start

    _, pred_int = torch.max(outputs, 1)
    # .item() works on any device. The former .cuda() call failed without a GPU.
    pred_str = classes[int(pred_int.item())]
    return outputs, pred_int, pred_str, elapsed


if __name__ == "__main__":
    img_path = r"D:\Classification_Demo\major_dataset_repo\split_data\test\0\0_116.png"
    model_path = major_config.path_saved_model
    net = get_model(model_path, False, input_size=(3, 32, 32))
    net.to(major_config.device)
    net.eval()

    with torch.no_grad():
        # Single-image prediction. Its time is not reported; presumably the
        # first call also serves as warm-up (CUDA/cuDNN initialisation).
        outputs, pred_int, pred_str, _ = _predict_image(net, img_path, major_config.device)
        print(outputs)
        print(pred_int)
        print(pred_str)

        # Timed predictions on up to 100 files from one test directory.
        # Sorting makes the selection deterministic. Slicing avoids an
        # IndexError when the directory has fewer than 100 files.
        path = r"D:\Classification_Demo\major_dataset_repo\split_data\test\0"
        file_path_list = [os.path.join(path, name) for name in sorted(os.listdir(path))]
        for file_path in file_path_list[:100]:
            outputs, pred_int, pred_str, elapsed = _predict_image(net, file_path, major_config.device)
            print("所耗时间:", elapsed)
            print(outputs)
            print(pred_int)
            print(pred_str)