【YOLOv5】datasets.py


  • 训练数据前需要对数据集进行处理,这里利用torch.utils.data下的两个类,DataLoader和Dataset
  • Dataset是封封装图像和标签,每次输出一张图像和对应的标签,[数据增强也在此实现]
  • DataLoader 指明了Dataset和batchsize shuffle 和 collate_fn

在这里插入图片描述

1.Dataset

1.1 基本结构

# dataset.py
from torch.utils.data import DataLoader,Dataset

class LoadImagesAndLables(Dataset):
    def __init__(self,img_path):
    	# 从文件中读取图像 和 标签,并解析
        self.imgs = img_path
    def __len__(self):
        # 数据集大小
        return len(self.imgs)
    def __getitem__(self,index):
        # 每次根据index 返回指定的图像和标签
        return torch.from_numpy(imgs),torch.from_numpy(labels)

1.2 实现基本功能

  • 初步实现Dataset功能,获取图像,解析标签
  • 图像返回格式为RGB ,通道为CWH,并且归一化
  • 标签size= (标签数量,6) ,其中6列分别为 图像索引,类别,x,y,w,h
  • collate_fn 是DataLoader中用到的,目的是对同一个batch内的数据进行打包,图像利用torch.stack(img,0)转换成(16,3,512,512)。标签利用torch.cat(label,0)转换成(N,6),其中N 为当前batch下的所有标注数量
import torchvision.datasets
import torch
from torch.utils.data import DataLoader,Dataset
from pathlib import Path
import numpy as np
import cv2
"""
***Dataset***
创建Dataset子类,用来创建image 和 label
"""

class LoadImagesAndLabels(torch.utils.data.Dataset):
    # 主要作用就是读取图像和标签 同时完成数据增强
    """
    # v1.0  2021.12.30  by wjl
    # 1. 输入图像txt 和 标签 txt ,获取所有图像和标签
    # 2. 默认输入图像和标签大小为 512  即不用调整
    # 3. 输出图像为tensor 类型 标签shape=[box_num,6]  [[img_index,cls,x,y,w,h],[img_index,cls,x,y,w,h],...]
    """
    def __init__(self,txt_path:str,img_size = 512):
        # 初始化一些参数
        self.img_txt_path  = txt_path
        self.imgs_path     = None
        self.labels_path   = None
        self.labels = {}   #{"img_name":[[],[]...]}
        self.get_imgs_and_labels()

        pass
    def __len__(self):
        # 返回数据集大小
        return len(self.imgs_path)
    def __getitem__(self,index):
        # 返回当前索引下的图像和标签
        # print(index)
        img_name = self.imgs_path[index]
        # print(img_name)
        img = cv2.imread(img_name)
        img = img[:,:,::-1].transpose(2,0,1) # BGR -> RGB   WHC -> CWH
        img = img/255.0
        label = self.labels[img_name]
        label_out = torch.zeros((len(label),6)) # 增加一列 保存图片序号  与图片对应
        label_out[:,1:] = torch.from_numpy(label)
        return torch.from_numpy(img), label_out

    def get_imgs_and_labels(self):
        # 根据TXT文件,获取所有图像路径和标签
        # ===============get image path==================
        p = Path(self.img_txt_path)
        assert p.suffix == ".txt" , "image path not txt"
        with open(self.img_txt_path,'r') as f:
            self.imgs_path = f.readlines()
        self.imgs_path = [img_name.strip() for img_name in self.imgs_path]
        assert self.imgs_path ,"No image found!"
        # ===============get label =====================
        # 图像和标签在同一路径下所以转换一下图像路径的后缀即可
        self.labels_path = [lp.replace('.jpg','.txt') for lp in self.imgs_path]
        for img_name in self.imgs_path:
            label_path = img_name.replace('.jpg','.txt')
            with open(label_path,'r') as f:
                local = np.array([x.split() for x in f.read().strip().splitlines()],dtype = np.float32)
                # print(local)
            self.labels[img_name] = local
    @staticmethod
    def collate_fn(batch):
        img,label = zip(*batch)
        for i,l in enumerate(label):
            l[:,0] = i
        return torch.stack(img,0), torch.cat(label,0)

2. DataLoader

  • DataLoader是在train中用到的,每次提供batchsize大小的数据,同时利用shuffle进行图像顺序打乱。

    path = "/home/.../ImageSets/Main/train.txt"
    dataset = LoadImagesAndLabels(path)
    # print(dataset[10])
    train_loader = DataLoader(dataset = dataset,batch_size=16,shuffle=True,collate_fn=dataset.collate_fn)
    for epoch in range(2):
        for i, data in enumerate(train_loader):
            img,label = data
            print("epoch: {}, {} inputs size {} labels size {}".format(epoch,i,img.size(),label.size()))
  • 运行后结果
epoch: 0, 0 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([261, 6])
epoch: 0, 1 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([274, 6])
epoch: 0, 2 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([265, 6])
epoch: 0, 3 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([281, 6])
epoch: 0, 4 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([294, 6])
epoch: 0, 5 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([278, 6])
epoch: 0, 6 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([283, 6])
epoch: 0, 7 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([261, 6])
epoch: 0, 8 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([285, 6])
epoch: 0, 9 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([267, 6])
epoch: 0, 10 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([229, 6])
epoch: 0, 11 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([275, 6])
epoch: 0, 12 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([284, 6])
epoch: 0, 13 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([257, 6])
epoch: 0, 14 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([287, 6])
epoch: 0, 15 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([274, 6])
epoch: 0, 16 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([257, 6])
epoch: 0, 17 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([257, 6])
epoch: 0, 18 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([236, 6])
epoch: 0, 19 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([254, 6])
epoch: 0, 20 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([290, 6])
epoch: 0, 21 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([259, 6])
epoch: 0, 22 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([285, 6])
epoch: 0, 23 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([257, 6])
epoch: 0, 24 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([279, 6])
epoch: 0, 25 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([225, 6])
epoch: 0, 26 inputs size torch.Size([3, 3, 512, 512]) labels size torch.Size([43, 6])
epoch: 1, 0 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([263, 6])
epoch: 1, 1 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([293, 6])
epoch: 1, 2 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([264, 6])
epoch: 1, 3 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([265, 6])
epoch: 1, 4 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([305, 6])
epoch: 1, 5 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([264, 6])
epoch: 1, 6 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([226, 6])
epoch: 1, 7 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([255, 6])
epoch: 1, 8 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([247, 6])
epoch: 1, 9 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([303, 6])
epoch: 1, 10 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([279, 6])
epoch: 1, 11 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([247, 6])
epoch: 1, 12 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([257, 6])
epoch: 1, 13 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([255, 6])
epoch: 1, 14 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([298, 6])
epoch: 1, 15 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([259, 6])
epoch: 1, 16 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([258, 6])
epoch: 1, 17 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([227, 6])
epoch: 1, 18 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([244, 6])
epoch: 1, 19 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([256, 6])
epoch: 1, 20 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([260, 6])
epoch: 1, 21 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([296, 6])
epoch: 1, 22 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([273, 6])
epoch: 1, 23 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([315, 6])
epoch: 1, 24 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([282, 6])
epoch: 1, 25 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([267, 6])
epoch: 1, 26 inputs size torch.Size([3, 3, 512, 512]) labels size torch.Size([39, 6])

3. 训练数据

  • 利用DataLoader获取batchsize个数据后,送入网络进行推理,获得的结果与label进行损失计算,根据优化策略,反向传播梯度,更新一次weight,实现小批次的训练过程。
  • 1
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

wangxiaobei2017

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值