【Pytorch】18.创建自定义数据集并根据文件名或对应文件名的文本文件获取labels

源码

MNIST_Training_By_FileName_Dataset
MNIST_Training_By_TXTLabel

简介

本文主要探讨两种不同的数据集获取labels的方法

  • 根据图片的文件名中获取文件标签
    在这里插入图片描述

  • 根据与图片名称相同的.txt文件获取文件名
    在这里插入图片描述

根据图片名称获取labels

主要的区别在__init__方法中

    def __init__(self, root_path, train, transform=None):
        self.root_path = root_path
        self.transform = transform
        if train:
            self.root_path = os.path.join(self.root_path, 'training')
        else:
            self.root_path = os.path.join(self.root_path, 'testing')

        self.img_paths = []
        self.labels = []
        for label_path in os.listdir(self.root_path):
            img_path = os.path.join(self.root_path, label_path)
            if os.path.isdir(img_path):
                for img in os.listdir(img_path):
                    # 使用正则获取图片名称中的信息
                    match = re.search(r'_(\d+)', img)
                    label = match.group(1)
                    # print(f'label: {label}')
                    pre_img_path = os.path.join(img_path, img)
                    self.img_paths.append(pre_img_path)
                    self.labels.append(label)

我们可以看到在我们获取图片名称后,我们需要使用正则化来提取文件名中含有的label:xxx_0.png

根据txt文件获取labels

主要的区别在__getitem__方法


    def __getitem__(self, index):
        # ../datasets/mnist_png/training/.../1.png
        img = self.imgs[index]

        # 仅获取文件名
        # 1.png
        img_name = os.path.basename(img)
        img = Image.open(img).convert('L')
        if self.transform is not None:
            img = self.transform(img)
        # ../datasets/mnist_png/labels/1.txt
        label_dir = os.path.join(self.label_path, img_name.replace('.png', '.txt'))
        # 从文件中获取内容
        with open(label_dir, 'r') as f:
            label = f.read().strip()
        return img, label

  1. 我们需要现根据图片的相对路径通过os.path.basename来获取文件名
  2. 然后根据图片名使用img_name.replace来将.png换成.txt然后在对应的labels文件夹下找到对应名称的文件来获取标签
  • 3
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
好的,那么我们就使用PyTorch内置的`ImageFolder`类来加载数据集。假设该数据集的路径为`./flower`,其中包含三个子文件夹,分别是`setosa`、`versicolor`、`virginica`,分别对应鸢尾花的三个品种。 以下是加载数据集的代码: ```python import torch import torchvision.transforms as transforms from torch.utils.data import DataLoader from torchvision.datasets import ImageFolder # 定义数据预处理 data_transforms = transforms.Compose([ transforms.Resize((224, 224)), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ]) # 加载数据集 dataset = ImageFolder('./flower', transform=data_transforms) # 定义数据加载器 batch_size = 32 data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True) ``` 其中,`data_transforms`定义了数据预处理的方式,包括将图片resize到$224 \times 224$,转成tensor,以及进行标准化处理。 `dataset`使用`ImageFolder`类加载数据集,其中`./flower`指定了数据集的路径,`transform=data_transforms`指定了数据预处理的方式。 `data_loader`定义了数据加载器,`batch_size`指定了每个batch的大小,`shuffle=True`指定了在每个epoch开始时是否对数据进行shuffle。 接下来,我们可以使用`torchvision.models.resnet18()`方法初始化网络,并手动将下载下来的权重给模型参数赋值。假设该权重文件的路径为`./resnet18_weights.pth`,我们可以使用以下代码加载权重: ```python from torchvision.models import resnet18 # 初始化模型 model = resnet18(num_classes=3) # 加载权重 weights = torch.load('./resnet18_weights.pth', map_location=torch.device('cpu')) model.load_state_dict(weights) ``` 其中,`resnet18(num_classes=3)`中的`num_classes=3`指定了模型输出的类别数,即鸢尾花数据集的三个品种。 `torch.load()`方法可以加载权重文件,其中`map_location=torch.device('cpu')`指定了将权重加载到CPU上。 最后,我们可以对鸢尾花进行分类训练。以下是训练代码: ```python # 定义损失函数和优化器 criterion = torch.nn.CrossEntropyLoss() optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9) # 训练模型 num_epochs = 10 for epoch in range(num_epochs): for i, (inputs, labels) in enumerate(data_loader): # 前向传播 outputs = model(inputs) loss = criterion(outputs, labels) # 反向传播和优化 optimizer.zero_grad() loss.backward() optimizer.step() # 每训练10个batch,输出一次信息 if (i + 1) % 10 == 0: print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}' .format(epoch+1, num_epochs, i+1, len(data_loader), loss.item())) ``` 其中,`torch.nn.CrossEntropyLoss()`定义了损失函数,`torch.optim.SGD()`定义了优化器,`lr=0.001`指定了学习率,`momentum=0.9`指定了动量大小。 在训练过程中,我们需要对每个batch进行前向传播、计算损失、反向传播和优化。完成一次epoch后,我们可以输出一次信息,包括当前epoch的编号、当前batch的编号、当前损失的大小等信息。 训练完成后,我们可以保存模型的参数,以便后续使用。以下是保存模型参数的代码: ```python # 保存模型参数 torch.save(model.state_dict(), './resnet18_iris.pth') ``` 其中,`model.state_dict()`返回了模型的参数字典,`torch.save()`可以将该字典保存到文件中,以便后续使用。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值