split_train_val_test

import os
import shutil
import random

def split_dataset(data_dir, train_dir, test_dir, val_dir, train_ratio=0.8, test_ratio=0.1, val_ratio=0.1):
    # 确保训练集、测试集和验证集目录存在
    os.makedirs(train_dir, exist_ok=True)
    os.makedirs(test_dir, exist_ok=True)
    os.makedirs(val_dir, exist_ok=True)
    
    # 获取所有类别
    classes = os.listdir(data_dir)
    
    for class_name in classes:
        # 创建类别目录
        class_dir = os.path.join(data_dir, class_name)
        if not os.path.isdir(class_dir):
            continue
        
        # 创建训练集、测试集和验证集的类别目录
        train_class_dir = os.path.join(train_dir, class_name)
        test_class_dir = os.path.join(test_dir, class_name)
        val_class_dir = os.path.join(val_dir, class_name)
        os.makedirs(train_class_dir, exist_ok=True)
        os.makedirs(test_class_dir, exist_ok=True)
        os.makedirs(val_class_dir, exist_ok=True)
        
        # 获取该类别下所有图片
        images = os.listdir(class_dir)
        random.shuffle(images)
        
        # 计算训练集、测试集和验证集的大小
        total_images = len(images)
        train_size = int(total_images * train_ratio)
        test_size = int(total_images * test_ratio)
        
        # 划分图片到训练集、测试集和验证集
        train_images = images[:train_size]
        test_images = images[train_size:train_size + test_size]
        val_images = images[train_size + test_size:]
        
        # 移动图片到训练集、测试集和验证集目录
        for image in train_images:
            shutil.copy2(os.path.join(class_dir, image), os.path.join(train_class_dir, image))
        
        for image in test_images:
            shutil.copy2(os.path.join(class_dir, image), os.path.join(test_class_dir, image))
        
        for image in val_images:
            shutil.copy2(os.path.join(class_dir, image), os.path.join(val_class_dir, image))
    
    print("数据集划分完成")

# 设置目录路径
data_dir = 'dataset'
train_dir = 'train'
test_dir = 'test'
val_dir = 'val'

# 划分数据集
split_dataset(data_dir, train_dir, test_dir, val_dir, train_ratio=0.8, test_ratio=0.1, val_ratio=0.1)
import os
import torch
import torch.nn as nn
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
from tqdm import tqdm

# 定义目录路径
test_dir = '/media/dell/data_4t/artidiffu/data_train/test'

# 定义数据转换
data_transforms = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# 自定义数据集类以包含路径
class ImageFolderWithPaths(datasets.ImageFolder):
    def __getitem__(self, index):
        original_tuple = super().__getitem__(index)
        path = self.imgs[index][0]
        return original_tuple + (path,)

# 创建数据集
test_dataset = ImageFolderWithPaths(test_dir, transform=data_transforms)

# 创建数据加载器
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4)

# 定义设备
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# 加载预训练的ResNet18模型并加载最佳模型权重
model = models.resnet18(pretrained=False)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 1)  # 二分类任务输出1个节点
model.load_state_dict(torch.load('best_model_weights.pth'))  # 加载保存的模型权重
model = model.to(device)

# 定义损失函数
criterion = nn.BCEWithLogitsLoss()

# 测试阶段
model.eval()
test_loss = 0.0
test_corrects = 0

# 保存分类精度大于0.9的图像文件名
correct_filenames = []

with torch.no_grad():
    for inputs, labels, paths in tqdm(test_loader):
        inputs = inputs.to(device)
        labels = labels.to(device).float().view(-1, 1)

        outputs = model(inputs)
        loss = criterion(outputs, labels)

        test_loss += loss.item() * inputs.size(0)
        probs = torch.sigmoid(outputs)
        preds = probs >= 0.5

        test_corrects += torch.sum(preds == labels.data)

        # 检查分类精度大于0.9的图像
        for i in range(len(labels)):
            if preds[i] == labels.data[i] and probs[i] >= 0.9:
                correct_filenames.append(paths[i])

test_epoch_loss = test_loss / len(test_dataset)
test_epoch_acc = test_corrects.double() / len(test_dataset)

print(f"Test Loss: {test_epoch_loss:.4f} Acc: {test_epoch_acc:.4f}")

# 打印分类精度大于0.9的图像文件名
print("图像文件名(分类精度大于0.9):")
for filename in correct_filenames:
    print(filename)

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值