# 需要修改的部分已标出 (the places you need to adapt for your own setup are marked)
import os
import json
import torch
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
# 修改模型导入 (change this import to your own model)
from get_F1_Recall.models.alexnet import AlexNet as create_model
from sklearn import metrics
import time
from PIL import Image
import torch.nn.functional as F
class CustomDataset(Dataset):
    """Image-folder dataset: each immediate sub-directory of ``data_root`` is a class.

    A sample's label is the index of its class sub-directory in sorted order,
    counting only directories. Stray files in ``data_root`` (e.g. ``.DS_Store``)
    therefore do not shift the labels.

    Args:
        data_root: Directory whose immediate sub-directories are the classes.
        transform: Optional callable applied to each loaded PIL image.
    """

    def __init__(self, data_root, transform=None):
        self.data_root = data_root
        self.transform = transform
        self.image_paths = []  # file path of every sample
        self.labels = []  # integer class label of every sample, parallel to image_paths
        self.load_data()

    def load_data(self):
        """Populate ``image_paths`` and ``labels`` by scanning ``data_root``."""
        # Filter to directories *before* enumerating so label i really is the
        # i-th class folder; enumerating raw listdir output let plain files
        # consume label indices.
        class_dirs = sorted(
            name for name in os.listdir(self.data_root)
            if os.path.isdir(os.path.join(self.data_root, name))
        )
        for label, class_dir in enumerate(class_dirs):
            class_path = os.path.join(self.data_root, class_dir)
            # Sorted for a deterministic sample order; nested directories are
            # skipped because Image.open cannot read them.
            for image_name in sorted(os.listdir(class_path)):
                image_path = os.path.join(class_path, image_name)
                if os.path.isfile(image_path):
                    self.image_paths.append(image_path)
                    self.labels.append(label)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The image is converted to RGB so every sample has 3 channels whatever
        the file's mode (grayscale, RGBA, palette). The file handle is closed
        right away instead of being left to the garbage collector.
        """
        img_path = self.image_paths[idx]
        with Image.open(img_path) as raw:
            img = raw.convert("RGB")
        if self.transform:
            img = self.transform(img)
        return img, self.labels[idx]
def _predict(model, dataloader, device):
    """Run ``model`` over ``dataloader`` and return ``(y_true, y_pred)`` label lists."""
    y_true = []
    y_pred = []
    with torch.no_grad():
        for images, labels in dataloader:
            images = images.to(device)
            # A 224x224 resized copy of the batch, passed to the model as a second input.
            img_224 = F.interpolate(images, size=(224, 224), mode='bilinear',
                                    align_corners=False)
            # TODO: adapt this forward call / output handling to your model.
            output = model(images, img_224)
            # Flatten to (batch, -1) rather than torch.squeeze: squeeze also drops
            # the batch dim when the final batch holds a single image, which made
            # the dim=1 reduction fail. argmax of the logits equals argmax of
            # their softmax, so the softmax step is unnecessary.
            output = output.reshape(images.size(0), -1)
            y_pred.extend(torch.argmax(output, dim=1).cpu().numpy())
            # Labels never need to leave the CPU.
            y_true.extend(labels.numpy())
    return y_true, y_pred


def _print_metrics(y_true, y_pred):
    """Print accuracy, recall, F1, precision and the confusion matrix."""
    print("Acc")
    print(metrics.accuracy_score(y_true, y_pred))
    print("Recall")
    # 'weighted' to match F1/Precision below. For single-label multiclass
    # targets, weighted recall and micro recall both equal accuracy, so the
    # printed value is the same as before.
    print(metrics.recall_score(y_true, y_pred, average='weighted'))
    print("F1")
    print(metrics.f1_score(y_true, y_pred, average='weighted'))
    print("Precision")
    print(metrics.precision_score(y_true, y_pred, average='weighted'))
    print("Confusion Matrix")
    print(metrics.confusion_matrix(y_true, y_pred))


def main(data_root="/home/ubuntu/???/new_?_data",
         model_weight_path="weights/AlexNet_best.pth",
         batch_size=32):
    """Evaluate a trained classifier on an image-folder dataset and print metrics.

    Args:
        data_root: Folder with one sub-directory per class.
            TODO: set this to your data.
        model_weight_path: state_dict checkpoint loaded into ``create_model()``.
            TODO: set this to your weights.
        batch_size: Evaluation batch size.

    Raises:
        ValueError: If ``data_root`` contains no images.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # TODO: adapt the transform to match the one used at training time.
    data_transform = transforms.Compose([transforms.Resize([256, 256]),
                                         transforms.ToTensor(),
                                         transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
    dataset = CustomDataset(data_root, transform=data_transform)
    if len(dataset) == 0:
        raise ValueError(f"No images found under {data_root!r}")
    # Sample order does not affect the metrics; a fixed order keeps runs reproducible.
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    T1 = time.time()
    # TODO: adapt model creation and weight loading to your model.
    model = create_model().to(device)
    model.load_state_dict(torch.load(model_weight_path, map_location=device))
    model.eval()
    y_true, y_pred = _predict(model, dataloader, device)
    T2 = time.time()
    print("\nUsed time: ", T2 - T1)
    _print_metrics(y_true, y_pred)
if __name__ == '__main__':
    # Run the evaluation only when executed as a script, not on import.
    main()