假设数据目录结构是data_dir/images
包含图像文件,data_dir/labels
包含对应的标签文件,并且图像和标签的文件名是匹配的。
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os
# 定义一个名为CustomDataset的类,继承自torch.utils.data.Dataset,用于自定义数据集
class CustomDataset(Dataset):
def __init__(self, data_dir, transform=None):
# 存储数据集的目录路径
self.data_dir = data_dir
# 存储图像和标签的预处理/变换操作
self.transform = transform
# 获取数据目录下"images"文件夹中的所有图像文件名,并存储在self.images列表中
self.images = os.listdir(os.path.join(data_dir, "images"))
# 获取数据目录下"labels"文件夹中的所有标签文件名,并存储在self.labels列表中
self.labels = os.listdir(os.path.join(data_dir, "labels"))
# 注意:这里假设图像和标签的文件名是一一对应的
# 定义__len__方法,返回数据集的大小
def __len__(self):
# 返回self.images列表的长度,即图像的数量
return len(self.images)
# 定义__getitem__方法,根据索引idx返回一个数据样本(图像+对应的标签)
def __getitem__(self, idx):
# 根据索引idx从self.images和self.labels列表中获取图像和标签的文件名,并拼接成完整的文件路径
image_path = os.path.join(self.data_dir, "images", self.images[idx])
label_path = os.path.join(self.data_dir, "labels", self.labels[idx])
# 使用PIL库加载图像文件,并将其转换为RGB格式(三通道彩色图像)
image = Image.open(image_path).convert('RGB')
# 使用PIL库加载标签文件,并将其转换为L格式(单通道灰度图像),这里假设标签是灰度图
label = Image.open(label_path).convert('L')
# 如果定义了预处理/变换操作,则对图像和标签应用这些操作
# 注意:在实际应用中,图像和标签可能需要不同的预处理/变换操作
if self.transform:
image = self.transform(image)
label = self.transform(label)
# 返回变换后的图像和标签作为一个数据样本
return image, label
接下来,我们使用CustomDataset
类来创建训练集和数据加载器(DataLoader
):
# 定义变换
transform = transforms.Compose([
transforms.Resize((64, 64)), # 调整图像大小到64x64
transforms.ToTensor(), # 将PIL图像转换为tensor
# 添加其他必要的变换...
])
# 创建训练集实例
train_dataset = CustomDataset(data_dir="path_to_your_data", transform=transform)
# 创建数据加载器
train_loader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)
使用训练集来训练模型:
# 定义你的模型
model = ...
# 定义损失函数和优化器
criterion = ...
optimizer = ...
# 训练模型
num_epochs = 10 # 设置训练的epoch数量
for epoch in range(num_epochs):
for images, labels in train_loader:
# 将数据发送到设备(CPU或GPU)上
images, labels = images.to(device), labels.to(device)
# 前向传播
outputs = model(images)
# 计算损失
loss = criterion(outputs, labels)
# 反向传播和优化
optimizer.zero_grad() # 清空之前的梯度
loss.backward() # 反向传播,计算当前梯度
optimizer.step() # 更新权重
# 打印统计信息
print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')