分类 - 鸢尾花 分类模型的测试

文章展示了如何使用PyTorch加载和测试鸢尾花分类模型。通过TensorDataset和DataLoader处理测试数据,加载预训练模型,然后使用torch.max获取每批数据的最大值和索引,进行预测并得出结果。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

分类 - 鸢尾花 分类模型的测试

flyfish

从文件中加载模型文件进行推理
可以根据预测结果和csv文件中的标签结果进行比对或者绘图等

import torch
from torch.utils.data import DataLoader, TensorDataset
import numpy as np

def predict(Net,config, test_X):
    # 获取测试数据
    test_set = TensorDataset(test_X)
    test_loader = DataLoader(test_set, batch_size=1)

    # 加载模型
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    model = Net(config.input_dim,config.output_dim).to(device)
    model.load_state_dict(torch.load(config.model_save_path + config.model_name))   # 加载模型参数

    # 先定义一个tensor保存预测结果
    result = torch.Tensor().to(device)

    # 预测过程
    model.eval()

    for _data in test_loader:
        data_X = _data[0].to(device)
        pred_X = model(data_X)
        pred_X = torch.max(pred_X, 1)[1] 
        result = torch.cat((result, pred_X))

    return result.detach().cpu().numpy()

推理的直接结果

tensor([[-13.5357,  -1.3558,   6.0250]])
tensor([[-7.7295, -0.2102,  2.1028]])
tensor([[-12.1690,  -1.0299,   5.1819]])
tensor([[-11.8842,  -0.8139,   4.5455]])
tensor([[-9.1968, -2.5394,  7.7020]])
tensor([[-6.3651, -1.3473,  4.4809]])
tensor([[-9.2604, -0.9389,  4.1727]])
tensor([[-8.4016, -1.6374,  5.5507]])
tensor([[-8.5926, -1.5390,  5.1352]])
tensor([[-7.9865, -2.6461,  7.5752]])

经过torch.max和torch.cat之后

[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 2. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]

标签含义是

Iris-setosa = 0
Iris-versicolor = 1
Iris-virginica = 2

解释torch.max

import torch
a = torch.randn(4, 3)
print(a)
# tensor([[-0.0459, -0.0215,  0.1406],
#         [ 0.4759,  1.0764, -2.8144],
#         [ 0.5207, -0.7760,  1.2992],
#         [-1.0217,  2.0276, -0.7463]])
print(torch.max(a))
print(torch.max(a, 1)) #包括值和索引  每行最大,一共4行
print(torch.max(a, 1)[0]) #值
print(torch.max(a, 1)[1]) #索引

# tensor(2.0276)
# torch.return_types.max(
# values=tensor([0.1406, 1.0764, 1.2992, 2.0276]),
# indices=tensor([2, 1, 2, 1]))
# tensor([0.1406, 1.0764, 1.2992, 2.0276])
# tensor([2, 1, 2, 1])



print(torch.max(a, 0)) #每列最大,一共3列
print(torch.max(a, 0)[0])
print(torch.max(a, 0)[1])
# torch.return_types.max(
# values=tensor([0.5207, 2.0276, 1.2992]),
# indices=tensor([2, 3, 2]))
# tensor([0.5207, 2.0276, 1.2992])
# tensor([2, 3, 2])
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

二分掌柜的

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值