模型的预测

模 型 的 预 测 模型的预测

import os
import time
from PIL import Image
from major_dataset import LoadDataset
import major_config
import torch
import torchvision.transforms as transforms
from torchsummary import summary
# 标签和类别的映射关系
classes = ["airplane", "automobile", "bird", "cat", "deer","dog", "frog", "horse", "ship", "truck"]

# 1.model.eval()
# 2.torch.no_grad()
# 3.数据预处理保持一致
# 4.预测时间的计算

inference_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(major_config.norm_mean, major_config.norm_std),
])

def preprocessing(img,transform = None):
    if transforms is None:
        raise Exception("无transform进行预处理")
    img_tensor = transform(img)
    return img_tensor

def get_model(saved_model_path=major_config.path_saved_model,visual_model=False,input_size=(3,32,32)):
    net = major_config.model
    net.load_state_dict(torch.load(saved_model_path))
    if visual_model:
        summary(net, input_size=input_size, device="cpu")
    return net


if __name__ == "__main__":
    # 1. data
    img_path = r"D:\Classification_Demo\major_dataset_repo\split_data\test\0\0_116.png"
    # 2. model
    model_path = major_config.path_saved_model
    net = get_model(model_path,False,input_size=(3,32,32))
    net.to(major_config.device)
    net.eval()
    # 3.单图predict
    with torch.no_grad():
        # step 1/4 : path --> img
        img_rgb = Image.open(img_path).convert('RGB')

        # step 2/4 : img --> tensor
        img_tensor = preprocessing(img_rgb,inference_transform)
        img_tensor.unsqueeze_(0)
        img_tensor = img_tensor.to(major_config.device)

        # step 3/4 : tensor --> vector
        time_start = time.time()
        outputs = net(img_tensor)
        time_end = time.time()

        # step 4/4 : visualization
        print(outputs)
        _,pred_int = torch.max(outputs,1)
        print(pred_int)
        pred_str = classes[int(pred_int.cuda().data.cpu().numpy())]
        print(pred_str)

    # 4.多图预测
    with torch.no_grad():
        # step 1/4 : path --> img
        path = r"D:\Classification_Demo\major_dataset_repo\split_data\test\0"
        files_list = os.listdir(path)
        file_path_list = [os.path.join(path, img) for img in files_list]

        for i in range(100):
            img_rgb = Image.open(file_path_list[i]).convert('RGB')

            # step 2/4 : img --> tensor
            img_tensor = preprocessing(img_rgb,inference_transform)
            img_tensor.unsqueeze_(0)
            img_tensor = img_tensor.to(major_config.device)

            # step 3/4 : tensor --> vector
            time_start = time.time()
            outputs = net(img_tensor)
            time_end = time.time()
            print("所耗时间:",time_end - time_start)

            # step 4/4 : visualization
            print(outputs)
            _,pred_int = torch.max(outputs,1)
            print(pred_int)
            pred_str = classes[int(pred_int.cuda().data.cpu().numpy())]
            print(pred_str)

在这里插入图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值