import os
import shutil
from tqdm import tqdm
import torch
import torch.nn as nn
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
from PIL import Image
import pandas as pd
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import numpy as np
# Input folder with the images to classify, and output folder that receives
# copies of the images predicted as class 1.
test_dir = '/media/wagnchogn/data_disk/artifact/revise_cla_normal_artifact/test_a2n'
new_path = '/media/wagnchogn/data_disk/artifact/revise_cla_normal_artifact/test_a2n_sel'
# exist_ok=True replaces the exists()-then-makedirs() pair, which could race
# if the folder is created between the check and the call.
os.makedirs(new_path, exist_ok=True)
# Preprocessing: resize to a fixed 256x256 (aspect ratio is not preserved),
# convert to a float tensor, and normalize with the usual ImageNet mean/std.
data_transforms = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# Use the first CUDA GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# ResNet-18 built without ImageNet weights; its classifier head is replaced by
# a single-output linear layer (binary task -> one logit per image), and the
# fine-tuned weights are then loaded from best_model_weights.pth.
model = models.resnet18(pretrained=False)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 1)
# map_location lets a checkpoint that was saved from a GPU load on a CPU-only
# machine; without it torch.load tries to restore tensors onto their original
# CUDA device and fails when CUDA is unavailable.
model.load_state_dict(torch.load('best_model_weights.pth', map_location=device))
model = model.to(device)
# Classify every file in test_dir, copy the images predicted as class 1 into
# new_path, and write per-image predictions plus raw logits to test_info.csv.
imgs = sorted(os.listdir(test_dir))  # sorted so the CSV / t-SNE row order is reproducible
pred_imgs = []
pred_labels = []
pred_logits = []
# Switch to inference mode once (BatchNorm/Dropout use eval behaviour) and
# disable autograd for the whole loop instead of per image.
model.eval()
with torch.no_grad():
    for img_name in tqdm(imgs):
        img_path = os.path.join(test_dir, img_name)
        if not os.path.isfile(img_path):
            continue  # skip sub-directories; Image.open / shutil.copy would fail on them
        # convert('RGB') guarantees the 3 channels the 3-element Normalize expects
        # (grayscale / RGBA / palette files would otherwise error out); the
        # with-block closes the file handle that Image.open keeps open.
        with Image.open(img_path) as img:
            input_tensor = data_transforms(img.convert('RGB')).unsqueeze(0).to(device)
        output = model(input_tensor)
        logits = output.cpu().numpy().flatten()  # raw (pre-sigmoid) score(s) for this image
        pred = torch.sigmoid(output).item()  # single output node -> probability of class 1
        pred_label = 0 if pred <= 0.5 else 1
        pred_imgs.append(img_path)
        pred_labels.append(pred_label)
        pred_logits.append(logits)
        if pred_label == 1:
            shutil.copy(img_path, os.path.join(new_path, img_name))

# Columns: img_path, pred_label, then one column per logit (named 0, 1, ...).
df = pd.DataFrame({'img_path': pred_imgs, 'pred_label': pred_labels})
logits_df = pd.DataFrame(pred_logits)
df = pd.concat([df, logits_df], axis=1)
df.to_csv('test_info.csv', index=False)
# Logit columns are selected by name (everything except the two metadata
# columns) rather than by position: with a single-output model the only logit
# column sits at position 2, so a positional slice from 3 would be empty and
# TSNE would reject the 0-feature matrix.
logit_cols = [col for col in df.columns if col not in ('img_path', 'pred_label')]
logits = df[logit_cols].to_numpy(dtype=float)
labels = df['pred_label'].to_numpy()
n_samples = logits.shape[0]

if n_samples < 2:
    print(f't-SNE skipped: need at least 2 images, got {n_samples}.')
else:
    # scikit-learn requires perplexity < n_samples; keep the default (30) when
    # there are enough images, otherwise shrink it so small folders still work.
    # NOTE(review): with one logit per image, t-SNE maps a 1-D feature to 2-D;
    # penultimate-layer features may give a more informative embedding.
    tsne = TSNE(n_components=2, random_state=42, perplexity=min(30.0, n_samples - 1))
    logits_tsne = tsne.fit_transform(logits)

    # Scatter the 2-D embedding, one colour per predicted label.
    plt.figure(figsize=(10, 8))
    for label in np.unique(labels):
        mask = labels == label
        plt.scatter(logits_tsne[mask, 0], logits_tsne[mask, 1], label=f'Label {label}', alpha=0.5)
    plt.legend()
    plt.title('t-SNE of Image Logits')
    plt.xlabel('t-SNE Component 1')
    plt.ylabel('t-SNE Component 2')
    plt.savefig('tsne_plot.png')
    plt.show()
# Stray pasted text (not code), kept as a comment: "Latest recommended article published 2024-09-04 20:18:34".