resnet18预测代码

调用resnet18模型以及训练好的模型参数
测试图片,预测时的数据预处理操作一定要和训练时的数据预处理操作一致
不一致可能会导致最终预测结果不对

import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image

# 定义设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_classes = 3
# 定义预处理操作
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# 加载模型
model = models.resnet18(weights=None)  # 不使用预训练的权重
model.fc = nn.Linear(512, n_classes)  # 假设 n_classes 是你的类别数
model.load_state_dict(torch.load("F:/resnet18/logs1/best_model.pth", map_location=device))  # 加载训练好的模型参数
model = model.to(device)
model.eval()  # 设置模型为评估模式,关闭 dropout 和 batch normalization

# 定义预测函数
def predict(image_path, model, preprocess):
    image = Image.open(image_path)
    image = preprocess(image).unsqueeze(0).to(device)  # 添加 batch 维度并移动到设备
    with torch.no_grad():  # 禁用梯度计算
        output = model(image)
        probabilities = torch.softmax(output, dim=1)
        predicted_class = torch.argmax(probabilities, dim=1).item()
        return predicted_class, probabilities[0, predicted_class].item()
dic1 = {
            0: '飞机', 1: '汽车', 2: '轿子'
        }  # 创建一个
# 示例用法
#image_path = "F:\cloud\sky1\sav1111\sky_2.jpg"  # 替换成你的图片路径 47次时预测成类别0卷云,实际情况类别为1,68次时预测为1正确,晴空和部分卷云图片相似
image_path = "F:/resnet18/dataset/test/stf/202402_17_073613.jpg"
predicted_class, confidence = predict(image_path, model, preprocess)
print("Predicted class:", dic1[predicted_class])
print("Confidence:", confidence)
  • 3
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值