inference.py篇

inference.py 篇

目录:

  • 前言
  • 思考自己需要载入的超参
  • 书写代码
  • 函数手册

前言

在该模块中加载训练好的模型,对测试集的image进行推理。

思考自己需要载入的超参

该模块的书写,是train的简约版,例如你可能需要设置和train相同的batch_sizedevicedataloader等信息,但是这次你不需要设置epoch等信息,对模型的参数进行优化等。

书写代码

书写顺序如下:

argparse()方法收集需要传递的所有参数,传入main函数中(可选)。

main函数中思路如下:

  1. 写路径等信息
  2. 书写dataloder。设置transformsdatasetdataloaderbatch_size等参数,因为dataloader中要用到。
  3. 设置其余超参,如device等,这次你必须要加载train中产生的预训练权重。
  4. 对测试集进行推理

下以AlexNet中的inference.py为例:

# add path
import os, sys
root_path = os.path.dirname(os.path.dirname(__file__))
project_path = os.path.dirname(__file__)
sys.path.append(project_path)
# add module
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
import json
import torch
import numpy as np
from model import AlexNet


def parse_args():
	"""get your args"""
	
def convert_image(image_path:str = ""):
    """transform png to jpg"""

def main():
    # 路径
    root_path       = os.path.dirname(os.path.dirname(__file__))
    project_path    = os.path.dirname(__file__)
    weight_path     = os.path.join(root_path, "weight", "AlexNet_2.pth")
    image_path      = "/home/yingmuzhi/AlexNet/daisy.jpg"
    # 加载预测图片
    img             = None
    data_transform  = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    img = Image.open(image_path)
    print(np.array(img).shape)
    img = data_transform(img)   # 只接受[height, width, channel=3]的图片, 即RGB的jpg
    img = torch.unsqueeze(img, dim = 0) # 传入网络需要[batch, channel, height, width]
    # 加载json文件
    try:
        json_file = open(project_path + "/class_indices.json","r")
        class_indict = json.load(json_file)
    except Exception as e:
        print(e)
        exit(-1)
    # 测试参数
    net = AlexNet(num_classes=2)
    net.load_state_dict(torch.load(weight_path))
    net.eval()    # 关闭dropout层并且不会梯度回传
    with torch.no_grad():
        # predict class
        output = net(img)
        # print(output.shape)
        output = torch.squeeze(output)
        # print(output.shape)
        predict = torch.softmax(output, dim = 0)
        # print(predict.shape)
        predict_cla = torch.argmax(predict).numpy()
    print(class_indict[str(predict_cla)], predict[predict_cla].item())

if __name__ == "__main__":
	args = parse_args()
    main(args)

函数手册

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值