matlab ud.save,YOLOV5训练与测试时数据加载dataset.py代码注释与解析

本文详细解析了在Python中使用PyTorch训练YOLOv5时,如何通过`ud.save`处理数据集,包括数据加载、数据增强、创建dataloader等步骤。讲解了`LoadImagesAndLabels`类的功能,以及`create_dataloader`函数中涉及的参数解析、图片和标签的处理。同时,还介绍了数据增强的实现,如随机旋转、平移、缩放等操作。
摘要由CSDN通过智能技术生成

import glob

import math

import os

import random

import shutil

import time

from pathlib import Path

from threading import Thread

import cv2

import numpy as np

import torch

from PIL import Image, ExifTags

from torch.utils.data import Dataset

from tqdm import tqdm

from utils.utils import xyxy2xywh, xywh2xyxy

help_url = 'https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data'

img_formats = ['.bmp', '.jpg', '.jpeg', '.png', '.tif', '.dng']

vid_formats = ['.mov', '.avi', '.mp4', '.mpg', '.mpeg', '.m4v', '.wmv', '.mkv']

# Get orientation exif tag

for orientation in ExifTags.TAGS.keys():

if ExifTags.TAGS[orientation] == 'Orientation':

break

# 此函数根据图片的信息获取图片的宽、高信息

def exif_size(img):

# Returns exif-corrected PIL size

s = img.size # (width, height)

try:

rotation = dict(img._getexif().items())[orientation]

if rotation == 6: # rotation 270

s = (s[1], s[0])

elif rotation == 8: # rotation 90

s = (s[1], s[0])

except:

pass

return s

# 根据LoadImagesAndLabels创建dataloader

def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=False, cache=False, pad=0.0, rect=False):

"""

参数解析:

path:包含图片路径的txt文件或者包含图片的文件夹路径

imgsz:网络输入图片大小

batch_size: 批次大小

stride:网络下采样最大总步长

opt:调用train.py时传入的参数,这里主要用到opt.single_cls,是否是单类数据集

hyp:网络训练时的一些超参数,包括学习率等,这里主要用到里面一些关于数据增强(旋转、平移等)的系数

augment:是否进行数据增强

cache:是否提前缓存图片到内存,以便加快训练速度

pad:设置矩形训练的shape时进行的填充

rect:是否进行矩形训练

"""

dataset = LoadImagesAndLabels(path, imgsz, batch_size,

augment=augment, # augment images

hyp=hyp, # augmentation hyperparameters

rect=rect, # rectangular training

cache_images=cache,

single_cls=opt.single_cls,

stride=int(stride),

pad=pad)

batch_size = min(batch_size, len(dataset))

nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8]) # number of workers

dataloader = torch.utils.data.DataLoader(dataset,

batch_size=batch_size,

num_workers=nw,

pin_memory=True,

collate_fn=LoadImagesAndLabels.collate_fn)

return dataloader, dataset

class LoadImagesAndLabels(Dataset): # for training/testing

def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, rect=False, image_weights=False,

cache_images=False, single_cls=False, stride=32, pad=0.0):

try:

f = []

for p in path if isinstance(path, list) else [path]:

# 获取数据集路径path,包含图片路径的txt文件或者包含图片的文件夹路径

# 使用pathlib.Path生成与操作系统无关的路径,因为不同操作系统路径的‘/’会有所不同

p = str(Path(p)) # os-agnostic

# 获取数据集路径的上级父目录,os.sep为路径里的破折号(不同系统路径破折号不同,os.sep根据系统自适应)

parent = str(Path(p).parent) + os.sep

# 如果路径path为包含图片路径的txt文件

if os.path.isfile(p): # file

with open(p, 'r') as t:

# 获取图片路径,更换相对路径

t = t.read().splitlines()

f += [x.replace('./', parent) if x.startswith('./') else x for x in t] # local to global path

# 如果路径path为包含图片的文件夹路径

elif os.path.isdir(p): # folder

f += glob.iglob(p + os.sep + '*.*')

else:

raise Exception('%s does not exist' % p)

path = p # *.npy dir

# 破折号替换为os.sep,os.path.splitext(x)将文件名与扩展名分开并返回一个列表

self.img_files = [x.replace('/', os.sep) for x in f if os.path.splitext(x)[-1].lower() in img_formats]

except Exception as e:

raise Exception('Error loading data from %s: %s\nSee %s' % (path, e, help_url))

# 数据集的数量

n = len(self.img_files)

assert n > 0, 'No images found in %s. See %s' % (path, help_url)

# 获取batch的索引

bi = np.floor(np.arange(n) / batch_size).astype(np.int) # batch index

# 一个轮次batch的数量

nb = bi[-1] + 1 # number of batches

self.n = n # number of images

self.batch = bi # batch index of image

self.img_size = img_size # 输入图片分辨率大小

self.augment = augment # 数据增强

self.hyp = hyp # 超参数

self.image_weights = image_weights # 图片采样

self.rect = False if image_weights else rect # 矩形训练

self.mosaic = self.augment and not self.rect # mosaic数据增强

self.mosaic_border = [-img_size // 2, -img_size // 2] # mosaic增强的边界

self.stride = stride # 模型下采样的总步长

# 获取数据集的标签

self.label_files = [x.replace('images', 'labels').replace(os.path.splitext(x)[-1], '.txt')

for x in self.img_files]

# 保存图片shape的路径

sp = path.replace('.txt', '') + '.shapes' # shapefile path

try:

# 如果存在该路径,则读取

with open(sp, 'r') as f: # read existing shapefile

s = [x.split() for x in f.read().splitlines()]

assert len(s) == n, 'Shapefile out of sync'

except:

# 如果不存在,则读取图片shape再保存

s = [exif_size(I

  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值