【KnowledgeBase】基于Pytorch建立一个自定义的目标检测DataLoader


前言

代码和文件夹免费公开,学习自取。链接!链接!链接!

本文介绍如何通过torch建立一个自己的目标检测数据集DataLoader。以WIDERFACE的部分图片与YOLO格式标注为例。本文分为以下4步介绍建立DataLoader的整体思路,具体还是要根据自己的数据集File格式进行调整:

  1. 数据集File格式介绍
  2. 代码整体思路及展示
  3. 代码分块介绍
  4. 代码测试

一、数据集File格式介绍

我们使用了4张WIDERFACE中的图片以及YOLO格式的标签来进行说明,整体的数据结构如下图,其中用来测试使用的代码文件DIY_DataLoader.ipynb也在同一目录下。
在这里插入图片描述

  1. imgaes中存放.jpg图片;
    在这里插入图片描述

  2. labels中存放.txt的YOLO格式标注文件;
    在这里插入图片描述
    在这里插入图片描述

  3. DIY_DataLoader.ipynb是测试用的代码文件;

  4. train.txt中罗列了图片的路径。
    在这里插入图片描述


二、代码整体思路及展示

2.1 代码整体思路

自己的DIY的DataLoader需要重写其中的一些方法,主要包括:__int____len____getitem__

  • __int__中保存一些数据集相关信息,最终为了得到:每一张图片路径、每一个标注路径、对图片进行的transform;
  • __len__为了得到一共有多少张图片数量;
  • __getitem__为了得到其中某一张图片的[image_array, gt_bbox]

2.2 代码整体展示

import os
import numpy as np
from PIL import Image
from torch.utils.data import Dataset, DataLoader
class WIDERFACE(Dataset):
    def __init__(self, root_dir, image_file, ann_file, ann_txt, transform=None):
        self.root_dir = root_dir        # Root file
        self.image_file = image_file    # Image file
        self.ann_file = ann_file        # Annotations file

        self.imagenames = self.load_imgnames(ann_txt)

        # Load imgs/annos file
        self.imgs = [f'{x}.jpg' for x in [os.path.join(root_dir, image_file, image) for image in self.imagenames]]
        self.annos = [f'{x}.txt' for x in [os.path.join(root_dir, ann_file, image) for image in self.imagenames]]

        self.transform = transform
    
    def __len__(self):
        return len(self.imagenames)
    
    def __getitem__(self, idx):
        image = np.array(Image.open(self.imgs[idx]).getdata())
        with open(self.annos[idx]) as f:
            gt_bbox = [x.strip('\n').split('/')[-1] for x in f.readlines()] # x, y, width, height
        sample = {'img': image, 'gt_bbox': gt_bbox}
        if self.transform:
            sample = self.transform(sample)
        return sample
    
    def load_imgnames(self, ann_txt):
        with open(ann_txt) as f:
            samples = [x.strip('\n').split('/')[-1] for x in f.readlines()]
            names = [x.split('.')[0] for x in samples]
        return names

三、代码分块介绍

这里将一块块地详细介绍下类中每一个方法的内容。

3.1 def load_imgnames

这块代码最终为了读取下每一张图片的名称,在我们的文件夹中,它的输入为train.txt

	def load_imgnames(self, ann_txt):
        with open(self, ann_txt) as f:
            samples = [x.strip('\n').split('/')[-1] for x in f.readlines()]
            names = [x.split('.')[0] for x in samples]
        return names

简单测试一下,就是
在这里插入图片描述

3.2 def _init_

这一块主要是保存并告诉一下DataLoader,图片文件的具体路径、图片标注框的具体路径、用了什么transform方法。

	def __init__(self, root_dir, image_file, ann_file, ann_txt, transform=None):
        self.root_dir = root_dir        # Root file         './'
        self.image_file = image_file    # Image file        'images/'
        self.ann_file = ann_file        # Annotations file  'labels/'

        self.imagenames = self.load_imgnames(ann_txt)   # 得到了每张图片的名称

        # 基于self.imagenames,得到每张图片的 imgs/annos 具体的路径
        self.imgs = [f'{x}.jpg' for x in [os.path.join(root_dir, image_file, image) for image in self.imagenames]]
        self.annos = [f'{x}.txt' for x in [os.path.join(root_dir, ann_file, image) for image in self.imagenames]]

        self.transform = transform

3.3 def _len_

self.imagenames是一个保存了所有图片名称的List,故使用len()方法可以知道一共有多少张图片。当然self.imagenames也可以替换成self.imgs或者self.annos,效果是一样的。

	def __len__(self):
        return len(self.imagenames)

3.4 def _getitem_

    def __getitem__(self, idx):
        # 根据图片路径打开图片并转化成np.array格式
        image = np.array(Image.open(self.imgs[idx]).getdata())
        # 保存图片对应的gt_bbox[x, y, width, height]
        with open(self.annos[idx]) as f:
            gt_bbox = [x.strip('\n').split('/')[-1] for x in f.readlines()]
        # 使用dict对一张图片的信息进行包装
        sample = {'img': image, 'gt_bbox': gt_bbox}
        if self.transform:
            sample = self.transform(sample)
        return sample

四、代码测试

我们使用这个由4张图片组成的数据集进行一下DIY_WIDERFACE这个DataLoader的代码测试。

root_file = './'
image_file = 'images/'
ann_file = 'labels/'
ann_txt = './train.txt'

test = DIY_WIDERFACE(root_file, image_file, ann_file, ann_txt)
  1. __init__方法中储藏的一些信息展示,如下:

在这里插入图片描述

  1. __len__方法表示的图片数量,如下:

在这里插入图片描述

  1. __getitem__方法展示某一张图片的信息,包括图片的数组信息、gt_bbox,如下:

在这里插入图片描述


总结

本文就简单地带大家理解下DataLoader的构造思路。
欢迎批评指正。

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
以下是一个基于PyTorch的遥感图像目标检测算法代码,使用的是Faster R-CNN模型: ```python import torch import torchvision from torchvision.models.detection.faster_rcnn import FastRCNNPredictor def get_model(num_classes): # 加载预训练的 Faster R-CNN 模型 model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True) # 替换分类器,使其适用于新的数据集 in_features = model.roi_heads.box_predictor.cls_score.in_features model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes) return model # 定义数据集 class MyDataset(torch.utils.data.Dataset): def __init__(self, images, targets): self.images = images self.targets = targets def __getitem__(self, index): image = self.images[index] target = self.targets[index] # 转换为 PyTorch 张量 image = torch.tensor(image, dtype=torch.float32) target = { 'boxes': torch.tensor(target['boxes'], dtype=torch.float32), 'labels': torch.tensor(target['labels'], dtype=torch.int64) } return image, target def __len__(self): return len(self.images) # 训练模型 def train_model(model, dataloader, optimizer, criterion): model.train() for images, targets in dataloader: images = list(image for image in images) targets = [{k: v for k, v in t.items()} for t in targets] optimizer.zero_grad() loss_dict = model(images, targets) losses = sum(loss for loss in loss_dict.values()) losses.backward() optimizer.step() # 测试模型 def test_model(model, dataloader): model.eval() with torch.no_grad(): for images, targets in dataloader: images = list(image for image in images) targets = [{k: v for k, v in t.items()} for t in targets] outputs = model(images) # TODO: 对模型输出进行处理,得到目标检测结果 # 训练数据集 train_images = [...] train_targets = [...] # 测试数据集 test_images = [...] test_targets = [...] # 创建数据集 train_dataset = MyDataset(train_images, train_targets) test_dataset = MyDataset(test_images, test_targets) # 创建数据加载器 train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=2, shuffle=True) test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=2, shuffle=False) # 创建模型 model = get_model(num_classes=2) # 假设有两个类别,例如车辆和建筑物 # 定义优化器和损失函数 optimizer = torch.optim.SGD(model.parameters(), lr=0.005, momentum=0.9, weight_decay=0.0005) criterion = torch.nn.CrossEntropyLoss() # 训练模型 for epoch in range(10): train_model(model, train_dataloader, optimizer, criterion) # 测试模型 test_model(model, test_dataloader) ``` 需要注意的是,在上面的代码中,你需要根据你的具体数据集修改 `MyDataset` 类中的代码,以及根据你的具体需求修改测试模型函数中的代码。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Prymce-Q

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值