import glob
import math
import os
import random
import shutil
import time
from pathlib import Path
from threading import Thread
import cv2
import numpy as np
import torch
from PIL import Image, ExifTags
from torch.utils.data import Dataset
from tqdm import tqdm
from utils.utils import xyxy2xywh, xywh2xyxy
# Documentation link quoted in dataset-loading error messages.
help_url = 'https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data'

# Lower-case file extensions accepted as still images (matched against
# os.path.splitext(...)[-1].lower() when collecting image files).
img_formats = [
    '.bmp',
    '.jpg',
    '.jpeg',
    '.png',
    '.tif',
    '.dng',
]

# Video container extensions.
vid_formats = [
    '.mov',
    '.avi',
    '.mp4',
    '.mpg',
    '.mpeg',
    '.m4v',
    '.wmv',
    '.mkv',
]
# Resolve, once at import time, the numeric EXIF tag id whose name is
# 'Orientation'; exif_size() uses it to read each image's rotation flag.
# The previous loop relied on the leaked loop variable and, had the name been
# absent, would silently have bound an unrelated tag id. Fall back explicitly
# to 0x0112 (274), the id the EXIF standard assigns to Orientation.
orientation = next(
    (tag for tag, name in ExifTags.TAGS.items() if name == 'Orientation'),
    0x0112,
)
# Get an image's width/height, corrected for its EXIF orientation.
def exif_size(img):
    """Return the EXIF-orientation-corrected ``(width, height)`` of a PIL image.

    When the image's EXIF orientation flag is 6 or 8, the stored pixel grid is
    meant to be displayed rotated by a quarter turn, so width and height are
    swapped. Any other flag value, or no flag at all, leaves ``img.size``
    unchanged.

    Args:
        img: a PIL image. It must expose ``.size``. ``._getexif()`` is
            optional; if it is missing or fails, the uncorrected size is
            returned.

    Returns:
        tuple: ``(width, height)``. This is ``img.size`` itself unless a swap
        was applied.
    """
    s = img.size  # (width, height)
    try:
        exif = img._getexif()  # None when the image carries no EXIF block
        rotation = exif.get(orientation) if exif else None
    except Exception:
        # Best effort: formats without _getexif or malformed EXIF data fall
        # back to the raw size. Unlike a bare `except:`, this no longer
        # swallows KeyboardInterrupt / SystemExit.
        rotation = None
    if rotation in (6, 8):  # quarter-turn orientations -> swap the axes
        s = (s[1], s[0])
    return s
# Create a DataLoader on top of LoadImagesAndLabels.
def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=False, cache=False, pad=0.0, rect=False):
    """Build a LoadImagesAndLabels dataset and wrap it in a torch DataLoader.

    Args:
        path: a .txt file listing image paths, or a folder containing images.
        imgsz: network input image size.
        batch_size: requested batch size; clamped to the dataset length.
        stride: maximum total downsampling stride of the network (cast to int).
        opt: options passed to train.py; only ``opt.single_cls`` (whether the
            dataset is single-class) is read here.
        hyp: training hyperparameters, forwarded to the dataset (per the
            original notes, mainly the augmentation coefficients such as
            rotation and translation).
        augment: whether to apply data augmentation.
        cache: whether to cache images in memory ahead of time to speed up
            training.
        pad: padding used when computing shapes for rectangular training.
        rect: whether to use rectangular training.

    Returns:
        tuple: ``(dataloader, dataset)``.
    """
    dataset = LoadImagesAndLabels(path, imgsz, batch_size,
                                  augment=augment,  # augment images
                                  hyp=hyp,  # augmentation hyperparameters
                                  rect=rect,  # rectangular training
                                  cache_images=cache,
                                  single_cls=opt.single_cls,
                                  stride=int(stride),
                                  pad=pad)

    # Never ask for more images per batch than the dataset holds.
    batch_size = min(batch_size, len(dataset))
    # os.cpu_count() returns None when the CPU count cannot be determined;
    # min() would then raise TypeError, so treat that as a single CPU.
    cpu_count = os.cpu_count() or 1
    # Worker processes: 0 (load in the main process) when batch_size is 1,
    # otherwise bounded by the CPU count, the batch size and a cap of 8.
    nw = min(cpu_count, batch_size if batch_size > 1 else 0, 8)  # number of workers
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=batch_size,
                                             num_workers=nw,
                                             pin_memory=True,
                                             collate_fn=LoadImagesAndLabels.collate_fn)
    return dataloader, dataset
class LoadImagesAndLabels(Dataset): # for training/testing
def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, rect=False, image_weights=False,
cache_images=False, single_cls=False, stride=32, pad=0.0):
try:
f = []
for p in path if isinstance(path, list) else [path]:
# 获取数据集路径path,包含图片路径的txt文件或者包含图片的文件夹路径
# 使用pathlib.Path生成与操作系统无关的路径,因为不同操作系统路径的‘/’会有所不同
p = str(Path(p)) # os-agnostic
# 获取数据集路径的上级父目录,os.sep为路径里的破折号(不同系统路径破折号不同,os.sep根据系统自适应)
parent = str(Path(p).parent) + os.sep
# 如果路径path为包含图片路径的txt文件
if os.path.isfile(p): # file
with open(p, 'r') as t:
# 获取图片路径,更换相对路径
t = t.read().splitlines()
f += [x.replace('./', parent) if x.startswith('./') else x for x in t] # local to global path
# 如果路径path为包含图片的文件夹路径
elif os.path.isdir(p): # folder
f += glob.iglob(p + os.sep + '*.*')
else:
raise Exception('%s does not exist' % p)
path = p # *.npy dir
# 破折号替换为os.sep,os.path.splitext(x)将文件名与扩展名分开并返回一个列表
self.img_files = [x.replace('/', os.sep) for x in f if os.path.splitext(x)[-1].lower() in img_formats]
except Exception as e:
raise Exception('Error loading data from %s: %s\nSee %s' % (path, e, help_url))
# 数据集的数量
n = len(self.img_files)
assert n > 0, 'No images found in %s. See %s' % (path, help_url)
# 获取batch的索引
bi = np.floor(np.arange(n) / batch_size).astype(np.int) # batch index
# 一个轮次batch的数量
nb = bi[-1] + 1 # number of batches
self.n = n # number of images
self.batch = bi # batch index of image
self.img_size = img_size # 输入图片分辨率大小
self.augment = augment # 数据增强
self.hyp = hyp # 超参数
self.image_weights = image_weights # 图片采样
self.rect = False if image_weights else rect # 矩形训练
self.mosaic = self.augment and not self.rect # mosaic数据增强
self.mosaic_border = [-img_size // 2, -img_size // 2] # mosaic增强的边界
self.stride = stride # 模型下采样的总步长
# 获取数据集的标签
self.label_files = [x.replace('images', 'labels').replace(os.path.splitext(x)[-1], '.txt')
for x in self.img_files]
# 保存图片shape的路径
sp = path.replace('.txt', '') + '.shapes' # shapefile path
try:
# 如果存在该路径,则读取
with open(sp, 'r') as f: # read existing shapefile
s = [x.split() for x in f.read().splitlines()]
assert len(s) == n, 'Shapefile out of sync'
except:
# 如果不存在,则读取图片shape再保存
s = [exif_size(I