PyTorch构建训练集

假设数据目录结构是data_dir/images包含图像文件,data_dir/labels包含对应的标签文件,并且图像和标签的文件名是匹配的。

import torch    
from torch.utils.data import Dataset, DataLoader    
from torchvision import transforms    
from PIL import Image    
import os  
  
# 定义一个名为CustomDataset的类,继承自torch.utils.data.Dataset,用于自定义数据集  
class CustomDataset(Dataset):    
    def __init__(self, data_dir, transform=None):    
        # 存储数据集的目录路径  
        self.data_dir = data_dir    
        # 存储图像和标签的预处理/变换操作  
        self.transform = transform    

        # 获取数据目录下"images"文件夹中的所有图像文件名,并存储在self.images列表中  
        self.images = os.listdir(os.path.join(data_dir, "images"))    
        # 获取数据目录下"labels"文件夹中的所有标签文件名,并存储在self.labels列表中  
        self.labels = os.listdir(os.path.join(data_dir, "labels"))    
        # 注意:这里假设图像和标签的文件名是一一对应的 
    
    # 定义__len__方法,返回数据集的大小
    def __len__(self):    
        # 返回self.images列表的长度,即图像的数量  
        return len(self.images)    
    
    # 定义__getitem__方法,根据索引idx返回一个数据样本(图像+对应的标签)  
    def __getitem__(self, idx):    
        # 根据索引idx从self.images和self.labels列表中获取图像和标签的文件名,并拼接成完整的文件路径  
        image_path = os.path.join(self.data_dir, "images", self.images[idx])    
        label_path = os.path.join(self.data_dir, "labels", self.labels[idx])    
            
        # 使用PIL库加载图像文件,并将其转换为RGB格式(三通道彩色图像)  
        image = Image.open(image_path).convert('RGB')    
        # 使用PIL库加载标签文件,并将其转换为L格式(单通道灰度图像),这里假设标签是灰度图  
        label = Image.open(label_path).convert('L')    
            
        # 如果定义了预处理/变换操作,则对图像和标签应用这些操作  
        # 注意:在实际应用中,图像和标签可能需要不同的预处理/变换操作
        if self.transform:    
            image = self.transform(image) 
            label = self.transform(label)  
    
        # 返回变换后的图像和标签作为一个数据样本  
        return image, label

接下来,我们使用CustomDataset类来创建训练集和数据加载器(DataLoader):

# 定义变换
transform = transforms.Compose([  
    transforms.Resize((64, 64)),  # 调整图像大小到64x64  
    transforms.ToTensor(),  # 将PIL图像转换为tensor  
    # 添加其他必要的变换...  
])  
  
# 创建训练集实例  
train_dataset = CustomDataset(data_dir="path_to_your_data", transform=transform)  
  
# 创建数据加载器  
train_loader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)

使用训练集来训练模型:

# 定义你的模型
model = ...  
  
# 定义损失函数和优化器
criterion = ...  
optimizer = ...  
  
# 训练模型  
num_epochs = 10  # 设置训练的epoch数量  
for epoch in range(num_epochs):  
    for images, labels in train_loader:  
        # 将数据发送到设备(CPU或GPU)上  
        images, labels = images.to(device), labels.to(device)  
          
        # 前向传播  
        outputs = model(images)  
          
        # 计算损失  
        loss = criterion(outputs, labels)  
          
        # 反向传播和优化  
        optimizer.zero_grad()  # 清空之前的梯度  
        loss.backward()  # 反向传播,计算当前梯度  
        optimizer.step()  # 更新权重  
          
        # 打印统计信息
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值