分类 - 鸢尾花 分类模型的测试
flyfish
从文件中加载模型文件进行推理
可以根据预测结果和csv文件中的标签结果进行比对或者绘图等
import torch
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
def predict(Net,config, test_X):
# 获取测试数据
test_set = TensorDataset(test_X)
test_loader = DataLoader(test_set, batch_size=1)
# 加载模型
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = Net(config.input_dim,config.output_dim).to(device)
model.load_state_dict(torch.load(config.model_save_path + config.model_name)) # 加载模型参数
# 先定义一个tensor保存预测结果
result = torch.Tensor().to(device)
# 预测过程
model.eval()
for _data in test_loader:
data_X = _data[0].to(device)
pred_X = model(data_X)
pred_X = torch.max(pred_X, 1)[1]
result = torch.cat((result, pred_X))
return result.detach().cpu().numpy()
推理的直接结果
tensor([[-13.5357, -1.3558, 6.0250]])
tensor([[-7.7295, -0.2102, 2.1028]])
tensor([[-12.1690, -1.0299, 5.1819]])
tensor([[-11.8842, -0.8139, 4.5455]])
tensor([[-9.1968, -2.5394, 7.7020]])
tensor([[-6.3651, -1.3473, 4.4809]])
tensor([[-9.2604, -0.9389, 4.1727]])
tensor([[-8.4016, -1.6374, 5.5507]])
tensor([[-8.5926, -1.5390, 5.1352]])
tensor([[-7.9865, -2.6461, 7.5752]])
经过torch.max和torch.cat之后
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 2. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
1. 1. 1. 1. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]
标签含义是
Iris-setosa = 0
Iris-versicolor = 1
Iris-virginica = 2
解释torch.max
import torch
a = torch.randn(4, 3)
print(a)
# tensor([[-0.0459, -0.0215, 0.1406],
# [ 0.4759, 1.0764, -2.8144],
# [ 0.5207, -0.7760, 1.2992],
# [-1.0217, 2.0276, -0.7463]])
print(torch.max(a))
print(torch.max(a, 1)) #包括值和索引 每行最大,一共4行
print(torch.max(a, 1)[0]) #值
print(torch.max(a, 1)[1]) #索引
# tensor(2.0276)
# torch.return_types.max(
# values=tensor([0.1406, 1.0764, 1.2992, 2.0276]),
# indices=tensor([2, 1, 2, 1]))
# tensor([0.1406, 1.0764, 1.2992, 2.0276])
# tensor([2, 1, 2, 1])
print(torch.max(a, 0)) #每列最大,一共3列
print(torch.max(a, 0)[0])
print(torch.max(a, 0)[1])
# torch.return_types.max(
# values=tensor([0.5207, 2.0276, 1.2992]),
# indices=tensor([2, 3, 2]))
# tensor([0.5207, 2.0276, 1.2992])
# tensor([2, 3, 2])