📚博客主页:knighthood2001
✨公众号:认知up吧 (目前正在带领大家一起提升认知,感兴趣可以来围观一下)
🎃知识星球:【认知up吧|成长|副业】介绍
❤️如遇文章付费,可先看看我公众号中是否发布免费文章❤️
🙏笔者水平有限,欢迎各位大佬指点,相互学习进步!
模型测试,模型测试,首先就是需要模型,然后,需要数据才能进行测试。
def test_model_process(model, test_dataloader):
# 设定测试所用到的设备,有GPU用GPU没有GPU用CPU
device = "cuda" if torch.cuda.is_available() else 'cpu'
# 讲模型放入到训练设备中
model = model.to(device)
# 初始化参数
test_corrects = 0.0
test_num = 0
# 只进行前向传播计算,不计算梯度,从而节省内存,加快运行速度
with torch.no_grad():
for test_data_x, test_data_y in test_dataloader:
# 将特征放入到测试设备中
test_data_x = test_data_x.to(device)
# 将标签放入到测试设备中
test_data_y = test_data_y.to(device)
# 设置模型为评估模式
model.eval()
# 前向传播过程,输入为测试数据集,输出为对每个样本的预测值
output= model(test_data_x)
# 查找每一行中最大值对应的行标
pre_lab = torch.argmax(output, dim=1)
# 如果预测正确,则准确度test_corrects加1
test_corrects += torch.sum(pre_lab == test_data_y.data)
# 将所有的测试样本进行累加
test_num += test_data_x.size(0)
# 计算测试准确率
test_acc = test_corrects.double().item() / test_num
print("测试的准确率为:", test_acc)
因此,定义了test_model_process(model, test_dataloader):
函数,通过传入模型和数据集。这里内容和前面的差别在于没有了enumerate函数,
for test_data_x, test_data_y in test_dataloader:
在之前的例子中,在循环
for step, (b_x, b_y) in enumerate(train_dataloader):
enumerate 函数被用来同时获取数据加载器 train_dataloader
产生的每个批次(batch)的索引(step)和该批次的内容(b_x 和 b_y)。虽然在那个特定的代码段中,step 可能没有被直接使用.
因此,这里其实也可以用,也可以不用上面的表达。
此外,由于是测试,不涉及到反向传播,我们不需要计算梯度。把他关掉可以加快计算速度。
打印预测和真实标签
batch_size设置为1的好处,就在于可以直接打印预测值和真实值。
def print_test_model(model, test_dataloader):
# 设定测试所用到的设备,有GPU用GPU没有GPU用CPU
device = "cuda" if torch.cuda.is_available() else 'cpu'
model = model.to(device)
classes = ['猫', '狗']
with torch.no_grad():
for b_x, b_y in test_dataloader:
b_x = b_x.to(device)
b_y = b_y.to(device)
# 设置模型为验证模型
model.eval()
output = model(b_x)
pre_lab = torch.argmax(output, dim=1)
result = pre_lab.item()
label = b_y.item()
print("预测值:", classes[result], "------", "真实值:", classes[label])
在以前的文章中,pre_lab和b_y都是包含批数数量的数据,比如16个数据。
现在由于批数变成1了,因此其里面都只有1个值。可以直接取出来进行对比。
全部代码
import torch
import torch.utils.data as Data
from torchvision import transforms
from model import GoogLeNet, Inception
from torchvision.datasets import ImageFolder
from PIL import Image
def test_data_process():
# 定义数据集的路径
ROOT_TRAIN = r'data\test'
normalize = transforms.Normalize([0.162, 0.151, 0.138], [0.058, 0.052, 0.048])
# 定义数据集处理方法变量
test_transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor(), normalize])
# 加载数据集
test_data = ImageFolder(ROOT_TRAIN, transform=test_transform)
test_dataloader = Data.DataLoader(dataset=test_data,
batch_size=1,
shuffle=True,
num_workers=0)
return test_dataloader
def test_model_process(model, test_dataloader):
# 设定测试所用到的设备,有GPU用GPU没有GPU用CPU
device = "cuda" if torch.cuda.is_available() else 'cpu'
# 讲模型放入到训练设备中
model = model.to(device)
# 初始化参数
test_corrects = 0.0
test_num = 0
# 只进行前向传播计算,不计算梯度,从而节省内存,加快运行速度
with torch.no_grad():
for test_data_x, test_data_y in test_dataloader:
# 将特征放入到测试设备中
test_data_x = test_data_x.to(device)
# 将标签放入到测试设备中
test_data_y = test_data_y.to(device)
# 设置模型为评估模式
model.eval()
# 前向传播过程,输入为测试数据集,输出为对每个样本的预测值
output = model(test_data_x)
# 查找每一行中最大值对应的行标
pre_lab = torch.argmax(output, dim=1)
# 如果预测正确,则准确度test_corrects加1
test_corrects += torch.sum(pre_lab == test_data_y.data)
# 将所有的测试样本进行累加
test_num += test_data_x.size(0)
# 计算测试准确率
test_acc = test_corrects.double().item() / test_num
print("测试的准确率为:", test_acc)
def print_test_model(model, test_dataloader):
# 设定测试所用到的设备,有GPU用GPU没有GPU用CPU
device = "cuda" if torch.cuda.is_available() else 'cpu'
model = model.to(device)
classes = ['猫', '狗']
with torch.no_grad():
for b_x, b_y in test_dataloader:
b_x = b_x.to(device)
b_y = b_y.to(device)
# 设置模型为验证模型
model.eval()
output = model(b_x)
pre_lab = torch.argmax(output, dim=1)
result = pre_lab.item()
label = b_y.item()
print("预测值:", classes[result], "------", "真实值:", classes[label])
if __name__ == "__main__":
# 加载模型
model = GoogLeNet(Inception, in_channels=3, out_channels=2)
model.load_state_dict(torch.load('best_model.pth'))
# 利用现有的模型进行模型的测试
test_dataloader = test_data_process()
test_model_process(model, test_dataloader)
# print_test_model()
最后
如果你从前面的文章,一直看过来,我相信,这里你也是一下就能理解的。