用自己的数据集训练的resnet50模型进行分类预测,pth2onnx,测试onnx模型

1、批量预测图片并保存结果图,用best_model.pth预测,测试代码如下

import os  #导入必要的库:os: 用于文件路径操作。
import json #json: 用于读取类别标签的 JSON 文件。
import torch  #torch: PyTorch 深度学习库。
from PIL import Image  #Image:用于图像操作的 PIL 库。
from torchvision import transforms  #transforms:PyTorch 中用于图像预处理的模块。
import matplotlib.pyplot as plt  #matplotlib.pyplot: 用于图像展示。
from model import resnet50
from model import resnet34

def main():  #定义 main() 函数作为程序的主入口。
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") #检查并设置设备为 GPU 或 CPU。
#定义图像的预处理操作 data_transform,包括调整大小、中心裁剪、转换为张量以及归一化。
    data_transform = transforms.Compose(
        [transforms.Resize(256),
         transforms.CenterCrop(224),
         transforms.ToTensor(),
         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

    # load image 加载待预测的图像,并进行预处理操作。其中:
    # img_path = "./1.jpg" # img_path 是待预测的图像文件路径。
    predict_path = "D:/project/ResNet50/predict"
    folder_path = os.path.join(predict_path + "/test/zhou")
    assert os.path.exists(folder_path), "file: '{}' dose not exist.".format(folder_path)

    # read class_indict
    json_path = 'D:/project/ResNet50/class_indices.json' #读取预定义类别标签的 JSON 文件 class_indices.json。
    assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)

    with open(json_path, "r") as f:
        class_indict = json.load(f)

    # create model
    # model = resnet34(num_classes=2).to(device)  #resnet34(num_classes=5).to(device)表示用的resnet34训的2类
    model = resnet50(num_classes=2).to(device)  #resnet50(num_classes=2).to(device)表示用的resnet50训的2类。

    # load model weights
    # weights_path = "./resNet50.pth"
    weights_path = "D:/project/ResNet50/runs/train/best_model.pth"  # 创建 ResNet50 模型,并根据需要加载已训练好的模型参数文件。
    assert os.path.exists(weights_path), "file: '{}' dose not exist.".format(weights_path)
    model.load_state_dict(torch.load(weights_path, map_location=device))

    # Prediction for each image in the folder
    for filename in os.listdir(folder_path):
        img_path = os.path.join(folder_path,filename)
        assert os.path.isfile(img_path), "File '{}' dose not exist.".format(img_path)

        img = Image.open(img_path)
        img_transformed = data_transform(img).unsqueeze(0).to(device)
    # prediction 执行预测操作
        model.eval() #将模型设置为评估模式。
        with torch.no_grad():  #使用 torch.no_grad() 禁用梯度计算,因为在预测阶段不需要计算梯度。
        # predict class,
            output = torch.squeeze(model(img_transformed)).cpu()  #通过模型进行前向传播预测,
            predict = torch.softmax(output, dim=0)  #并对输出进行 softmax 归一化以获取预测类别的概率。
            # predict_cla = torch.argmax(predict).numpy()
            predict_cla = torch.argmax(predict).item()
            class_name = class_indict[str(predict_cla)]
            #在图像上绘制预测结果
        plt.imshow(img)
        plt.axis('off')
        plt.title(f'Predicted class:{class_name}')

        #保存带有预测结果的图像
        result_folder = os.path.join(predict_path,"result/zhou")
        if not os.path.exists(result_folder):
            os.makedirs(result_folder)
        result_path = os.path.join(result_folder,filename)
        plt.savefig(result_path,bbox_inches='tight',pad_inches=0.1)
        plt.close()

        print(f"File: '{filename}' - Predicted class: {class_name}.Result saved to:{result_path}")

if __name__ == '__main__':
    main()

2、pth2onnx代码如下:

import torch
from model import resnet50  # 导入模型结构
import onnx

print(torch.__version__)

pth_filename = './best_model.pth'  # 训练好的权重
onnx_filename = './best_model.onnx'
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net = resnet50(num_classes=2).to(device)  # num_classes 根据你的类别个数调整
net.load_state_dict(torch.load(pth_filename, map_location=device))
net.eval()
dummy_input = torch.randn(1, 3, 224, 224, device=device)
torch.onnx.export(net, dummy_input, onnx_filename,
                  input_names=['input'], output_names=['ouput'],
                  export_params=True, verbose=False, opset_version=12,
                  dynamic_axes={'input': {0: "batch_size"},
                                'output': {0: "batch_size"}})

# 检查一下生成的onnx
test = onnx.load("best_model.onnx")
onnx.checker.check_model(test)
print("==> Passed")

3、批量预测文件夹中所有图像并保存预测结果的完整代码,用best_model.onnx模型预测

代码如下:

import onnxruntime
import numpy as np
import cv2
import os
import json
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt

# 加载ONNX模型
onnx_model_path = "D:/project/Resnet50/best_model.onnx"
ort_session = onnxruntime.InferenceSession(onnx_model_path)

#定义图像的预处理操作
data_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
])

# 加载类别标签
json_path = 'D:/project/Resnet50/class_indices.json'
assert os.path.exists(json_path), "File '{}' dose not exist.".format(json_path)

with open(json_path, "r") as f:
    class_indict = json.load(f)

# 推理函数
def infer(image_path):
    # 预处理图像
    image = Image.open(image_path)
    img_transformed = data_transform(image).unsqueeze(0)

    # 执行推理
    ort_inputs = {ort_session.get_inputs()[0].name: img_transformed.numpy()}
    ort_outs = ort_session.run(None, ort_inputs)

    # 解析结果
    predictions = ort_outs[0]
    predicted_class = np.argmax(predictions, axis=1)[0]
    confidence = predictions[0][predicted_class]
    class_name = class_indict[str(predicted_class)]

    return class_name, confidence

# 测试图像的路径
test_image_path = "D:/project/Resnet50/zhou"

#设置保存预测结果的文件夹路径
save_path = "D:/project/Resnet50/result"

#创建保存预测结果的文件夹
if not os.path.exists(save_path):
    os.makedirs(save_path)

#遍历文件夹中的所有图像文件
for filename in os.listdir(test_image_path):
    #构造图像文件的完整路径
    image_path = os.path.join(test_image_path,filename)

    #执行推理
    predicted_class,confidence = infer(image_path)

    #构造保存预测结果的文件路径
    result_filename = f"{filename.split('.')[0]}_predicted_{predicted_class}.jpg"
    result_path = os.path.join(save_path,result_filename)

    #将预测结果绘制在图像上并保存
    img = Image.open(image_path)
    plt.imshow(img)
    plt.axis('off')
    plt.title(f'Predicted class:{predicted_class}, Confidence:{confidence}')
    plt.savefig(result_path,bbox_inches='tight',pad_inches=0.1)
    plt.close()

    print(f"Image '{filename} prediction saved to '{result_path}")

  • 3
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值