import os
import torch
import torch.nn as nn
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
from tqdm import tqdm
# 定义目录路径
test_dir = '/media/wagnchogn/data_disk/artifact/revise_cla_normal_artifact/dataset_test'
# 定义数据转换
data_transforms = transforms.Compose([
transforms.Resize((256, 256)),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
# 自定义数据集类以包含路径
class ImageFolderWithPaths(datasets.ImageFolder):
def __getitem__(self, index):
original_tuple = super().__getitem__(index)
path = self.imgs[index][0]
return original_tuple + (path,)
# 创建数据集
test_dataset = ImageFolderWithPaths(test_dir, transform=data_transforms)
# 创建数据加载器
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4)
# 定义设备
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# 加载预训练的ResNet18模型并加载最佳模型权重
model = models.resnet18(pretrained=False)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 1) # 二分类任务输出1个节点
model.load_state_dict(torch.load('best_model_weights.pth')) # 加载保存的模型权重
model = model.to(device)
# 定义损失函数
criterion = nn.BCEWithLogitsLoss()
# 测试阶段
model.eval()
test_loss = 0.0
test_corrects = 0
# 保存a2n数据集中预测为normal且概率大于0.9的图像文件名
a2n_normal_filenames = []
with torch.no_grad():
for batch in tqdm(test_loader):
inputs, labels, paths = batch
inputs = inputs.to(device)
labels = labels.to(device).float().view(-1, 1)
outputs = model(inputs)
loss = criterion(outputs, labels)
test_loss += loss.item() * inputs.size(0)
probs = torch.sigmoid(outputs)
preds = probs >= 0.5
test_corrects += torch.sum(preds == labels.data)
# 检查a2n数据集中预测为normal且概率大于0.9的图像
for i in range(len(labels)):
if 'a2n' in paths[i] and preds[i] == 1 and probs[i] >= 0.9:
a2n_normal_filenames.append(paths[i])
test_epoch_loss = test_loss / len(test_dataset)
test_epoch_acc = test_corrects.double() / len(test_dataset)
print(f"Test Loss: {test_epoch_loss:.4f} Acc: {test_epoch_acc:.4f}")
# 打印a2n数据集中预测为normal且概率大于0.9的图像文件名
print("a2n数据集中预测为normal且概率大于0.9的图像文件名:")
for filename in a2n_normal_filenames:
print(filename)
test——gpt
最新推荐文章于 2024-10-09 14:16:04 发布