import os
import shutil
import random
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
from PIL import Image
import pandas as pd
# Folder holding the test images to classify, and the folder that receives a
# copy of every image the model predicts as class 1.
test_dir = '/media/wagnchogn/data_disk/artifact/revise_cla_normal_artifact/test_a2n'
new_path = '/media/wagnchogn/data_disk/artifact/revise_cla_normal_artifact/test_a2n_sel'
# Create the output folder if needed. exist_ok=True replaces the separate
# exists()/makedirs() check, which could race if the folder is created in
# between the two calls.
os.makedirs(new_path, exist_ok=True)
#output_dir = '/media/wagnchogn/data_disk/artifact/revise_cla_normal_artifact/test_a2n'
#a2n_dir = '/media/wagnchogn/data_disk/artifact/revise_cla_normal_artifact/a2n'
# Preprocessing applied to every test image:
#   1. resize to a fixed 256x256,
#   2. convert the PIL image to a float tensor,
#   3. normalize each of the 3 channels with the given mean/std (these are the
#      commonly used ImageNet statistics; presumably they match the
#      training-time preprocessing — confirm against the training script).
data_transforms = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# Run on the first CUDA GPU when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Build a ResNet-18 *without* ImageNet-pretrained weights (pretrained=False),
# swap its classifier head for a single-output linear layer (binary task:
# one logit), then load the fine-tuned weights from disk.
model = models.resnet18(pretrained=False)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 1)  # binary classification -> 1 output node
# map_location remaps the stored tensors onto `device`. Without it, a
# checkpoint saved from a GPU fails to load on a CPU-only machine, even
# though the device selection above explicitly supports CPU fallback.
model.load_state_dict(torch.load('best_model_weights.pth', map_location=device))
model = model.to(device)
# This script only runs inference: put layers such as BatchNorm into
# evaluation mode once, right after loading.
model.eval()

# Loss matching the single-logit head. NOTE(review): it is not used anywhere
# in this inference script (and no optimizer is defined); kept only so the
# module-level name remains available.
criterion = nn.BCEWithLogitsLoss()
# Classify every entry of test_dir, copy the images predicted as class 1 into
# new_path, and write per-image results to test_info.csv.
# os.listdir returns entries in arbitrary order; sorting makes the processing
# order and the CSV row order reproducible.
imgs = sorted(os.listdir(test_dir))
pred_imgs = []
pred_labels = []
pred_logits = []
model.eval()  # set once, instead of on every loop iteration
with torch.no_grad():  # inference only: no autograd bookkeeping needed
    for img_name in tqdm(imgs):
        img_path = os.path.join(test_dir, img_name)
        # Close the file handle promptly, and force 3-channel RGB: the
        # Normalize step uses 3 channel statistics, so grayscale, RGBA or
        # palette images would otherwise fail. (torchvision's ImageFolder
        # default loader performs the same convert('RGB').)
        with Image.open(img_path) as raw_img:
            img = raw_img.convert('RGB')
        batch = data_transforms(img).unsqueeze(0).to(device)  # add batch dim
        output = model(batch)
        prob = torch.sigmoid(output)  # (1, 1) probability of class 1
        pred_label = 0 if prob.item() <= 0.5 else 1
        # NOTE: stored in the 'logits' column for compatibility with existing
        # consumers of the CSV, but this value is the post-sigmoid
        # probability, not a raw logit.
        logit = prob.cpu().numpy()[0][0]
        pred_imgs.append(img_path)
        pred_labels.append(pred_label)
        pred_logits.append(logit)
        if pred_label == 1:
            shutil.copy(img_path, os.path.join(new_path, img_name))

df = pd.DataFrame({'img_path': pred_imgs, 'pred_label': pred_labels, 'logits': pred_logits})
df.to_csv('test_info.csv')