dataset_path = './data3_dir/'
dataset = datasets.ImageFolder(root=dataset_path, transform=transform)
# 划分数据集为训练集、测试集和验证集
train_size = int(0.8 * len(dataset))
#val_size = int(0.1 * len(dataset))
test_size=len(dataset) - train_size
#训练集、验证集、测试集比例分为8.1.1
train_dataset,test_dataset = random_split(dataset, [train_size,test_size])
# 创建数据加载器
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
#val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
# 输出训练集和测试集的数量
print(f"训练集数量: {len(train_dataset)}")
print(f"测试集数量: {len(test_dataset)}")
# 在训练循环之前检查测试数据
for images, labels in train_loader:
print("Train data shape:", images.shape)
print("Train labels shape:", labels.shape)
break
# 在测试循环之前检查测试数据
for images, labels in test_loader:
print("Test data shape:", images.shape)
print("Test labels shape:", labels.shape)
break