ultralytics/data/dataset.py 代码阅读
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
import json
from collections import defaultdict
from itertools import repeat
from multiprocessing.pool import ThreadPool
from pathlib import Path
import cv2
import numpy as np
import torch
from PIL import Image
from torch.utils.data import ConcatDataset
from ultralytics.utils import LOCAL_RANK, NUM_THREADS, TQDM, colorstr
from ultralytics.utils.ops import resample_segments
from ultralytics.utils.torch_utils import TORCHVISION_0_18
from .augment import (
Compose,
Format,
Instances,
LetterBox,
RandomLoadText,
classify_augmentations,
classify_transforms,
v8_transforms,
)
from .base import BaseDataset
from .utils import (
HELP_URL,
LOGGER,
get_hash,
img2label_paths,
load_dataset_cache_file,
save_dataset_cache_file,
verify_image,
verify_image_label,
)
# Ultralytics dataset *.cache version, >= 1.0.0 for YOLOv8
DATASET_CACHE_VERSION = "1.0.3"
# 用于加载 YOLO 格式的数据集,并支持多种任务(如目标检测、分割、关键点检测和 OBB 检测)。
# 它还提供了缓存标签、数据增强和数据批处理的功能。
class YOLODataset(BaseDataset):
"""
Dataset class for loading object detection and/or segmentation labels in YOLO format.
Args:
data (dict, optional): A dataset YAML dictionary. Defaults to None.
task (str): An explicit arg to point current task, Defaults to 'detect'.
Returns:
(torch.utils.data.Dataset): A PyTorch dataset object that can be used for training an object detection model.
"""
# 根据任务类型(task)设置是否使用分割(segments)、关键点(keypoints)或 OBB(obb)。
# 保存数据集的 YAML 配置(data)。
def __init__(self, *args, data=None, task="detect", **kwargs):
"""Initializes the YOLODataset with optional configurations for segments and keypoints."""
self.use_segments = task == "segment"
self.use_keypoints = task == "pose"
self.use_obb = task == "obb"
self.data = data
# 断言:不能同时使用分割和关键点, 确保不能同时启用分割和关键点检测。
assert not (self.use_segments and self.use_keypoints), "Can not use both segments and keypoints."
super().__init__(*args, **kwargs)
# 缓存数据集的标签,检查图像并读取其形状。
def cache_labels(self, path=Path("./labels.cache")):
"""
Cache dataset labels, check images and read shapes.
Args:
path (Path): Path where to save the cache file. Default is Path("./labels.cache").
Returns:
(dict): labels.
"""
# 初始化统计变量和消息列表。
x = {"labels": []}
nm, nf, ne, nc, msgs = 0, 0, 0, 0, [] # number missing, found, empty, corrupt, messages
desc = f"{self.prefix}Scanning {path.parent / path.stem}..."
total = len(self.im_files)
# 如果任务是关键点检测,检查 kpt_shape 配置是否正确
nkpt, ndim = self.data.get("kpt_shape", (0, 0))
if self.use_keypoints and (nkpt <= 0 or ndim not in {2, 3}):
raise ValueError(
"'kpt_shape' in data.yaml missing or incorrect. Should be a list with [number of "
"keypoints, number of dims (2 for x,y or 3 for x,y,visible)], i.e. 'kpt_shape: [17, 3]'"
)
# 使用多线程池(ThreadPool)并行验证图像和标签。
with ThreadPool(NUM_THREADS) as pool:
results = pool.imap(
# 调用 verify_image_label 函数检查每个图像和标签文件。
func=verify_image_label,
iterable=zip(
self.im_files,
self.label_files,
repeat(self.prefix),
repeat(self.use_keypoints),
repeat(len(self.data["names"])),
repeat(nkpt),
repeat(ndim),
),
)
pbar = TQDM(results, desc=desc, total=total)
# 收集验证结果并更新统计信息
for im_file, lb, shape, segments, keypoint, nm_f, nf_f, ne_f, nc_f, msg in pbar:
nm += nm_f
nf += nf_f
ne += ne_f
nc += nc_f
if im_file:
x["labels"].append(
{
"im_file": im_file,
"shape": shape,
"cls": lb[:, 0:1], # n, 1
"bboxes": lb[:, 1:], # n, 4
"segments": segments,
"keypoints": keypoint,
"normalized": True,
"bbox_format": "xywh",
}
)
if msg:
msgs.append(msg)
pbar.desc = f"{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt"
pbar.close()
if msgs:
LOGGER.info("\n".join(msgs))
if nf == 0:
LOGGER.warning(f"{self.prefix}WARNING ⚠️ No labels found in {path}. {HELP_URL}")
x["hash"] = get_hash(self.label_files + self.im_files)
x["results"] = nf, nm, ne, nc, len(self.im_files)
x["msgs"] = msgs # warnings
save_dataset_cache_file(self.prefix, path, x, DATASET_CACHE_VERSION)
return x
# 加载数据集的标签信息,支持从缓存文件中读取或重新生成缓存。
def get_labels(self):
"""Returns dictionary of labels for YOLO training."""
# 使用 img2label_paths 将图像路径转换为标签路径。
self.label_files = img2label_paths(self.im_files)
# 定义缓存文件路径。
cache_path = Path(self.label_files[0]).parent.with_suffix(".cache")
# 尝试加载缓存文件
try:
cache, exists = load_dataset_cache_file(cache_path), True # attempt to load a *.cache file
assert cache["version"] == DATASET_CACHE_VERSION # matches current version
assert cache["hash"] == get_hash(self.label_files + self.im_files) # identical hash
# 如果缓存文件不存在或版本不匹配,则重新生成缓存。
except (FileNotFoundError, AssertionError, AttributeError):
cache, exists = self.cache_labels(cache_path), False # run cache ops
# Display cache 显示缓存信息。
nf, nm, ne, nc, n = cache.pop("results") # found, missing, empty, corrupt, total
if exists and LOCAL_RANK in {-1, 0}:
d = f"Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt"
TQDM(None, desc=self.prefix + d, total=n, initial=n) # display results
# 如果存在消息,记录警告信息。
if cache["msgs"]:
LOGGER.info("\n".join(cache["msgs"])) # display warnings
# Read cache 从缓存中读取标签信息
[cache.pop(k) for k in ("hash", "version", "msgs")] # remove items
labels = cache["labels"]
# 如果没有找到任何标签,记录警告信息。
if not labels:
LOGGER.warning(f"WARNING ⚠️ No images found in {cache_path}, training may not work correctly. {HELP_URL}")
# 更新图像文件路径
self.im_files = [lb["im_file"] for lb in labels] # update im_files
# Check if the dataset is all boxes or all segments
# 检查数据集是否包含边界框和分割信息。
lengths = ((len(lb["cls"]), len(lb["bboxes"]), len(lb["segments"])) for lb in labels)
len_cls, len_boxes, len_segments = (sum(x) for x in zip(*lengths))
# 如果边界框和分割信息的数量不一致,记录警告信息并移除分割信息。
if len_segments and len_boxes != len_segments:
LOGGER.warning(
f"WARNING ⚠️ Box and segment counts should be equal, but got len(segments) = {len_segments}, "
f"len(boxes) = {len_boxes}. To resolve this only boxes will be used and all segments will be removed. "
"To avoid this please supply either a detect or segment dataset, not a detect-segment mixed dataset."
)
for lb in labels:
lb["segments"] = []
# 如果没有找到任何类别信息,记录警告信息。
if len_cls == 0:
LOGGER.warning(f"WARNING ⚠️ No labels found in {cache_path}, training may not work correctly. {HELP_URL}")
return labels
# 构建并追加数据增强变换
def build_transforms(self, hyp=None):
"""Builds and appends transforms to the list."""
# 如果启用数据增强,设置 Mosaic 和 Mixup 的比例
if self.augment:
hyp.mosaic = hyp.mosaic if self.augment and not self.rect else 0.0
hyp.mixup = hyp.mixup if self.augment and not self.rect else 0.0
# 调用 v8_transforms 构建数据增强变换
transforms = v8_transforms(self, self.imgsz, hyp)
else: # 如果不启用数据增强,仅使用 LetterBox 变换。
transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz), scaleup=False)])
# 使用 Format 变换格式化数据,支持边界框、分割、关键点和 OBB。
# 根据任务类型设置是否返回分割、关键点和 OBB 信息
transforms.append(
Format(
bbox_format="xywh",
normalize=True,
return_mask=self.use_segments,
return_keypoint=self.use_keypoints,
return_obb=self.use_obb,
batch_idx=True,
mask_ratio=hyp.mask_ratio,
mask_overlap=hyp.overlap_mask,
bgr=hyp.bgr if self.augment else 0.0, # only affect training.
)
)
return transforms
# 关闭 Mosaic、Copy-Paste 和 Mixup 数据增强,并重新构建变换。
def close_mosaic(self, hyp):
"""Sets mosaic, copy_paste and mixup options to 0.0 and builds transformations."""
hyp.mosaic = 0.0 # set mosaic ratio=0.0
hyp.copy_paste = 0.0 # keep the same behavior as previous v8 close-mosaic
hyp.mixup = 0.0 # keep the same behavior as previous v8 close-mosaic
# 调用 build_transforms 方法重新构建变换
self.transforms = self.build_transforms(hyp)
# 更新标签信息,支持自定义标签格式。
def update_labels_info(self, label):
"""
Custom your label format here.
Note:
cls is not with bboxes now, classification and semantic segmentation need an independent cls label
Can also support classification and semantic segmentation by adding or removing dict keys there.
"""
# 提取边界框、分割、关键点、边界框格式和归一化信息
bboxes = label.pop("bboxes")
segments = label.pop("segments", [])
keypoints = label.pop("keypoints", None)
bbox_format = label.pop("bbox_format")
normalized = label.pop("normalized")
# NOTE: do NOT resample oriented boxes
# 如果任务是 OBB 检测,设置分割重采样数量为 100;否则设置为 1000。
segment_resamples = 100 if self.use_obb else 1000
# 如果有分割信息,确保分割的长度不超过 segment_resamples。
# 使用 resample_segments 重采样分割信息。
if len(segments) > 0:
# make sure segments interpolate correctly if original length is greater than segment_resamples
max_len = max(len(s) for s in segments)
segment_resamples = (max_len + 1) if segment_resamples < max_len else segment_resamples
# list[np.array(segment_resamples, 2)] * num_samples
segments = np.stack(resample_segments(segments, n=segment_resamples), axis=0)
else:
segments = np.zeros((0, segment_resamples, 2), dtype=np.float32)
# 创建 Instances 对象,包含边界框、分割、关键点、边界框格式和归一化信息
label["instances"] = Instances(bboxes, segments, keypoints, bbox_format=bbox_format, normalized=normalized)
return label
# 批处理:通过 collate_fn 方法将数据样本批量化。
@staticmethod
def collate_fn(batch):
"""Collates data samples into batches."""
# 初始化新批次字典。提取批次中的所有键和值。
new_batch = {}
keys = batch[0].keys()
values = list(zip(*[list(b.values()) for b in batch]))
# 处理每个键的值
for i, k in enumerate(keys):
value = values[i]
# 如果键是 img,使用 torch.stack 将图像堆叠成一个批次。
if k == "img":
value = torch.stack(value, 0)
# 如果键是 masks、keypoints、bboxes、cls、segments 或 obb,使用 torch.cat 将值拼接成一个批次。
if k in {"masks", "keypoints", "bboxes", "cls", "segments", "obb"}:
value = torch.cat(value, 0)
new_batch[k] = value
# 处理 batch_idx
new_batch["batch_idx"] = list(new_batch["batch_idx"]) # 将 batch_idx 转换为列表。
for i in range(len(new_batch["batch_idx"])):
# 为每个目标图像索引添加偏移量。
new_batch["batch_idx"][i] += i # add target image index for build_targets()
# 使用 torch.cat 将 batch_idx 拼接成一个批
new_batch["batch_idx"] = torch.cat(new_batch["batch_idx"], 0)
return new_batch
# 继承自 YOLODataset,用于加载 YOLO 格式的数据集,并支持多模态模型训练(结合图像和文本)。
class YOLOMultiModalDataset(YOLODataset):
"""
Dataset class for loading object detection and/or segmentation labels in YOLO format.
Args:
data (dict, optional): A dataset YAML dictionary. Defaults to None.
task (str): An explicit arg to point current task, Defaults to 'detect'.
Returns:
(torch.utils.data.Dataset): A PyTorch dataset object that can be used for training an object detection model.
"""
def __init__(self, *args, data=None, task="detect", **kwargs):
"""Initializes a dataset object for object detection tasks with optional specifications."""
super().__init__(*args, data=data, task=task, **kwargs)
# 在父类的基础上,为多模态模型训练添加文本信息
def update_labels_info(self, label):
"""Add texts information for multi-modal model training."""
# 调用父类的 update_labels_info 方法,获取基础的标签信息
labels = super().update_labels_info(label)
# NOTE: some categories are concatenated with its synonyms by `/`.
# 从数据集的类别名称中提取文本信息。
# 假设类别名称中可能包含用 / 分隔的同义词,将它们分割成列表。
# 例如,如果类别名称是 "cat/dog",则分割为 ["cat", "dog"]。
labels["texts"] = [v.split("/") for _, v in self.data["names"].items()]
return labels
# 在父类的基础上,为多模态训练增加文本增强变换
def build_transforms(self, hyp=None):
"""Enhances data transformations with optional text augmentation for multi-modal training."""
# 调用父类的 build_transforms 方法,获取基础的数据增强变换
transforms = super().build_transforms(hyp)
# 如果启用数据增强(self.augment 为 True),在变换列表中插入 RandomLoadText。
if self.augment:
# NOTE: hard-coded the args for now.
# RandomLoadText 是一个自定义的数据增强变换,用于加载随机文本样本
# max_samples 参数限制了最大样本数量,取数据集类别数量(self.data["nc"])和 80 的较小值。
# padding 参数设置为 True,表示对文本进行填充。
transforms.insert(-1, RandomLoadText(max_samples=min(self.data["nc"], 80), padding=True))
return transforms
# 继承自 YOLODataset,用于处理目标检测任务,特别支持从指定的 JSON 文件中加载标注信息。
# 它主要用于“grounding”任务,即将自然语言描述与图像中的对象进行对齐
class GroundingDataset(YOLODataset):
"""Handles object detection tasks by loading annotations from a specified JSON file, supporting YOLO format."""
def __init__(self, *args, task="detect", json_file, **kwargs):
"""Initializes a GroundingDataset for object detection, loading annotations from a specified JSON file."""
# 确保任务类型为 detect,因为当前版本仅支持目标检测任务
assert task == "detect", "`GroundingDataset` only support `detect` task for now!"
# 保存标注文件的路径
self.json_file = json_file
# 调用父类 YOLODataset 的初始化方法,传入空的 data 字典
super().__init__(*args, task=task, data={}, **kwargs)
# 这个方法在父类中用于获取图像文件路径,
# 但在 GroundingDataset 中,图像文件路径将在 get_labels 方法中读取,因此这里返回空列表。
def get_img_files(self, img_path):
"""The image files would be read in `get_labels` function, return empty list here."""
return []
def get_labels(self):
"""Loads annotations from a JSON file, filters, and normalizes bounding boxes for each image."""
labels = [] # 初始化标签列表。
LOGGER.info("Loading annotation file...") # 记录加载标注文件的日志信息。
# 打开并加载 JSON 文件,解析为 Python 对象
with open(self.json_file) as f:
annotations = json.load(f)
images = {f"{x['id']:d}": x for x in annotations["images"]}
# 使用 defaultdict 构建图像 ID 到标注的映射。
img_to_anns = defaultdict(list)
for ann in annotations["annotations"]:
img_to_anns[ann["image_id"]].append(ann)
# 遍历每个图像的标注信息,使用 TQDM 显示进度条
for img_id, anns in TQDM(img_to_anns.items(), desc=f"Reading annotations {self.json_file}"):
img = images[f"{img_id:d}"]
# 获取图像的高度、宽度和文件名
h, w, f = img["height"], img["width"], img["file_name"]
im_file = Path(self.img_path) / f
# 检查图像文件是否存在,如果不存在则跳过。
if not im_file.exists():
continue
# 将图像路径添加到 self.im_files
self.im_files.append(str(im_file))
# 初始化边界框列表、类别到 ID 的映射和文本列表
bboxes = []
cat2id = {}
texts = []
# 跳过 iscrowd 标注。
for ann in anns:
if ann["iscrowd"]:
continue
box = np.array(ann["bbox"], dtype=np.float32)
# 将边界框从 [x, y, w, h] 转换为 [x_center, y_center, w, h]。
box[:2] += box[2:] / 2
# 将边界框坐标归一化到 [0, 1] 范围。
box[[0, 2]] /= float(w)
box[[1, 3]] /= float(h)
# 跳过宽度或高度为零的边界框。
if box[2] <= 0 or box[3] <= 0:
continue
# 从图像的描述中提取类别名称
caption = img["caption"]
cat_name = " ".join([caption[t[0] : t[1]] for t in ann["tokens_positive"]])
# 如果类别名称不在 cat2id 中,为其分配一个新的类别 ID,并将其添加到文本列表中
if cat_name not in cat2id:
cat2id[cat_name] = len(cat2id)
texts.append([cat_name])
cls = cat2id[cat_name] # class
# 将类别 ID 添加到边界框信息中。
box = [cls] + box.tolist()
# 如果边界框未重复,则将其添加到边界框列表中
if box not in bboxes:
bboxes.append(box)
# 将边界框列表转换为 NumPy 数组
lb = np.array(bboxes, dtype=np.float32) if len(bboxes) else np.zeros((0, 5), dtype=np.float32)
# 构建标签字典,包含图像路径、形状、类别、边界框、归一化信息、边界框格式和文本信息
labels.append(
{
"im_file": im_file,
"shape": (h, w),
"cls": lb[:, 0:1], # n, 1
"bboxes": lb[:, 1:], # n, 4
"normalized": True,
"bbox_format": "xywh",
"texts": texts,
}
)
return labels
# 调用父类的 build_transforms 方法,获取基础的数据增强变换
def build_transforms(self, hyp=None):
"""Configures augmentations for training with optional text loading; `hyp` adjusts augmentation intensity."""
transforms = super().build_transforms(hyp)
# 增加文本增强
# 如果启用数据增强,插入 RandomLoadText 变换。
if self.augment:
# NOTE: hard-coded the args for now.
# max_samples=80:限制最大文本样本数量为 80。
# padding=True:对文本进行填充。
transforms.insert(-1, RandomLoadText(max_samples=80, padding=True))
return transforms
# 继承自 ConcatDataset,用于组合多个数据集。它提供了一个静态方法 collate_fn,用于将数据样本批量化。
class YOLOConcatDataset(ConcatDataset):
"""
Dataset as a concatenation of multiple datasets.
This class is useful to assemble different existing datasets.
"""
# 静态方法
# collate_fn 调用 YOLODataset 的 collate_fn 方法,用于将数据样本批量化。
@staticmethod
def collate_fn(batch):
"""Collates data samples into batches."""
return YOLODataset.collate_fn(batch)
# TODO: support semantic segmentation
class SemanticDataset(BaseDataset):
"""
Semantic Segmentation Dataset.
This class is responsible for handling datasets used for semantic segmentation tasks. It inherits functionalities
from the BaseDataset class.
Note:
This class is currently a placeholder and needs to be populated with methods and attributes for supporting
semantic segmentation tasks.
"""
def __init__(self):
"""Initialize a SemanticDataset object."""
super().__init__()
# 扩展了 torchvision.datasets.ImageFolder,用于支持 YOLO 分类任务。
# 它提供了图像增强、缓存和验证功能,支持在 RAM 或磁盘上缓存图像,以减少训练过程中的 IO 开销。
class ClassificationDataset:
"""
Extends torchvision ImageFolder to support YOLO classification tasks, offering functionalities like image
augmentation, caching, and verification. It's designed to efficiently handle large datasets for training deep
learning models, with optional image transformations and caching mechanisms to speed up training.
This class allows for augmentations using both torchvision and Albumentations libraries, and supports caching images
in RAM or on disk to reduce IO overhead during training. Additionally, it implements a robust verification process
to ensure data integrity and consistency.
Attributes:
cache_ram (bool): Indicates if caching in RAM is enabled.
cache_disk (bool): Indicates if caching on disk is enabled.
samples (list): A list of tuples, each containing the path to an image, its class index, path to its .npy cache
file (if caching on disk), and optionally the loaded image array (if caching in RAM).
torch_transforms (callable): PyTorch transforms to be applied to the images.
"""
def __init__(self, root, args, augment=False, prefix=""):
"""
Initialize YOLO object with root, image size, augmentations, and cache settings.
Args:
root (str): Path to the dataset directory where images are stored in a class-specific folder structure.
args (Namespace): Configuration containing dataset-related settings such as image size, augmentation
parameters, and cache settings. It includes attributes like `imgsz` (image size), `fraction` (fraction
of data to use), `scale`, `fliplr`, `flipud`, `cache` (disk or RAM caching for faster training),
`auto_augment`, `hsv_h`, `hsv_s`, `hsv_v`, and `crop_fraction`.
augment (bool, optional): Whether to apply augmentations to the dataset. Default is False.
prefix (str, optional): Prefix for logging and cache filenames, aiding in dataset identification and
debugging. Default is an empty string.
"""
import torchvision # scope for faster 'import ultralytics'
# Base class assigned as attribute rather than used as base class to allow for scoping slow torchvision import
# 根据 torchvision 的版本,初始化 ImageFolder 数据集
if TORCHVISION_0_18: # 'allow_empty' argument first introduced in torchvision 0.18
self.base = torchvision.datasets.ImageFolder(root=root, allow_empty=True)
else:
self.base = torchvision.datasets.ImageFolder(root=root)
# 获取数据集的样本列表和根目录路径
self.samples = self.base.samples
self.root = self.base.root
# Initialize attributes
# 如果启用增强且数据集比例小于 1.0,减少训练数据的比例
if augment and args.fraction < 1.0: # reduce training fraction
self.samples = self.samples[: round(len(self.samples) * args.fraction)]
# 初始化日志前缀
self.prefix = colorstr(f"{prefix}: ") if prefix else ""
# 检查是否启用 RAM 缓存,如果启用,记录警告信息并禁用 RAM 缓存。
self.cache_ram = args.cache is True or str(args.cache).lower() == "ram" # cache images into RAM
if self.cache_ram:
LOGGER.warning(
"WARNING ⚠️ Classification `cache_ram` training has known memory leak in "
"https://github.com/ultralytics/ultralytics/issues/9824, setting `cache_ram=False`."
)
self.cache_ram = False
# 检查是否启用磁盘缓存
self.cache_disk = str(args.cache).lower() == "disk" # cache images on hard drive as uncompressed *.npy files
# 调用 verify_images 方法验证图像文件。
self.samples = self.verify_images() # filter out bad images
# 为每个样本添加 .npy 文件路径和图像数组(如果启用 RAM 缓存)。
self.samples = [list(x) + [Path(x[0]).with_suffix(".npy"), None] for x in self.samples] # file, index, npy, im
# 定义图像增强或变换
scale = (1.0 - args.scale, 1.0) # (0.08, 1.0)
# 如果启用增强,调用 classify_augmentations 定义增强变换。
self.torch_transforms = (
classify_augmentations(
size=args.imgsz,
scale=scale,
hflip=args.fliplr,
vflip=args.flipud,
erasing=args.erasing,
auto_augment=args.auto_augment,
hsv_h=args.hsv_h,
hsv_s=args.hsv_s,
hsv_v=args.hsv_v,
)
if augment
# 如果不启用增强,调用 classify_transforms 定义基本变换。
else classify_transforms(size=args.imgsz, crop_fraction=args.crop_fraction)
)
# 返回数据集中的一个样本及其对应的类别标签。
def __getitem__(self, i):
"""Returns subset of data and targets corresponding to given indices."""
# 获取样本的文件路径、类别索引、.npy 文件路径和图像数组
f, j, fn, im = self.samples[i] # filename, index, filename.with_suffix('.npy'), image
# 如果启用 RAM 缓存且图像数组为空,加载图像并缓存到 RAM 中
if self.cache_ram:
if im is None: # Warning: two separate if statements required here, do not combine this with previous line
im = self.samples[i][3] = cv2.imread(f)
# 如果启用磁盘缓存且 .npy 文件不存在,加载图像并保存为 .npy 文件。
elif self.cache_disk:
if not fn.exists(): # load npy
np.save(fn.as_posix(), cv2.imread(f), allow_pickle=False)
# 加载 .npy 文件中的图像。
im = np.load(fn)
# 如果不启用缓存,直接读取图像。
else: # read image
im = cv2.imread(f) # BGR
# Convert NumPy array to PIL image
# 将 NumPy 数组转换为 PIL 图像: 将图像从 BGR 格式转换为 RGB 格式,并转换为 PIL 图像
im = Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))
# 应用定义的 PyTorch 变换。
sample = self.torch_transforms(im)
return {"img": sample, "cls": j}
def __len__(self) -> int:
"""Return the total number of samples in the dataset."""
return len(self.samples)
# 验证数据集中的所有图像,检查图像文件的完整性和一致性。
def verify_images(self):
"""Verify all images in dataset."""
# 定义一个描述信息,用于日志记录和进度条显示
desc = f"{self.prefix}Scanning {self.root}..."
# 定义缓存文件路径:将数据集根目录的扩展名替换为 .cache,生成缓存文件路径
path = Path(self.root).with_suffix(".cache") # *.cache file path
# 尝试加载缓存文件,验证其版本和哈希值是否匹配
try:
# 调用 load_dataset_cache_file 加载缓存文件
cache = load_dataset_cache_file(path) # attempt to load a *.cache file
# 检查缓存文件的版本是否与当前版本一致
assert cache["version"] == DATASET_CACHE_VERSION # matches current version
# 检查缓存文件的哈希值是否与当前数据集的哈希值一致
assert cache["hash"] == get_hash([x[0] for x in self.samples]) # identical hash
# 如果缓存文件有效,提取验证结果(nf:找到的图像数,nc:损坏的图像数,n:总图像数,samples:有效样本列表)。
nf, nc, n, samples = cache.pop("results") # found, missing, empty, corrupt, total
# 如果当前进程是主进程(LOCAL_RANK in {-1, 0}),显示验证结果并记录警告信息。
if LOCAL_RANK in {-1, 0}:
d = f"{desc} {nf} images, {nc} corrupt"
TQDM(None, desc=d, total=n, initial=n)
if cache["msgs"]:
LOGGER.info("\n".join(cache["msgs"])) # display warnings
# 返回有效样本列表
return samples
# 如果缓存文件加载失败,运行扫描
except (FileNotFoundError, AssertionError, AttributeError):
# Run scan if *.cache retrieval failed
# 初始化统计变量:nf(找到的图像数),nc(损坏的图像数),msgs(警告信息列表),samples(有效样本列表),x(缓存字典)。
nf, nc, msgs, samples, x = 0, 0, [], [], {}
# 使用多线程池验证每个图像文件
with ThreadPool(NUM_THREADS) as pool:
# 使用 imap 方法并行调用 verify_image 函数,验证每个图像文件
results = pool.imap(func=verify_image, iterable=zip(self.samples, repeat(self.prefix)))
pbar = TQDM(results, desc=desc, total=len(self.samples))
# 遍历验证结果
for sample, nf_f, nc_f, msg in pbar:
# 如果图像有效,将其添加到 samples 列表
if nf_f:
samples.append(sample)
# 如果有警告信息,将其添加到 msgs 列表
if msg:
msgs.append(msg)
# 更新统计变量 nf 和 nc
nf += nf_f
nc += nc_f
# 更新进度条描述信息
pbar.desc = f"{desc} {nf} images, {nc} corrupt"
pbar.close()
# 如果存在警告信息,记录并显示它们
if msgs:
LOGGER.info("\n".join(msgs))
# 保存验证结果到缓存文件
# 计算当前数据集的哈希值
x["hash"] = get_hash([x[0] for x in self.samples])
# 将验证结果(nf、nc、len(samples)、samples)和警告信息(msgs)保存到缓存字典 x。
x["results"] = nf, nc, len(samples), samples
x["msgs"] = msgs # warnings
# 保存缓存文件
save_dataset_cache_file(self.prefix, path, x, DATASET_CACHE_VERSION)
return samples
ultralytics/data/loaders.py 代码阅读
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
import glob
import math
import os
import time
from dataclasses import dataclass
from pathlib import Path
from threading import Thread
from urllib.parse import urlparse
import cv2
import numpy as np
import requests
import torch
from PIL import Image
from ultralytics.data.utils import FORMATS_HELP_MSG, IMG_FORMATS, VID_FORMATS
from ultralytics.utils import IS_COLAB, IS_KAGGLE, LOGGER, ops
from ultralytics.utils.checks import check_requirements
from ultralytics.utils.patches import imread
# 使用 @dataclass 装饰器定义一个数据类。用于表示不同类型的输入源(如视频流、截图、图像文件等)。
@dataclass
class SourceTypes:
"""
Class to represent various types of input sources for predictions.
This class uses dataclass to define boolean flags for different types of input sources that can be used for
making predictions with YOLO models.
Attributes:
stream (bool): Flag indicating if the input source is a video stream.
screenshot (bool): Flag indicating if the input source is a screenshot.
from_img (bool): Flag indicating if the input source is an image file.
Examples:
>>> source_types = SourceTypes(stream=True, screenshot=False, from_img=False)
>>> print(source_types.stream)
True
>>> print(source_types.from_img)
False
"""
# 定义布尔标志,表示输入源的类型。默认值为 False。
stream: bool = False
screenshot: bool = False
from_img: bool = False
tensor: bool = False
# 加载和处理多种类型的视频流,支持 RTSP、RTMP、HTTP 和 TCP 流。该类可以同时处理多个视频流,适用于实时视频分析任务。
class LoadStreams:
"""
Stream Loader for various types of video streams.
Supports RTSP, RTMP, HTTP, and TCP streams. This class handles the loading and processing of multiple video
streams simultaneously, making it suitable for real-time video analysis tasks.
Attributes:
sources (List[str]): The source input paths or URLs for the video streams.
vid_stride (int): Video frame-rate stride.
buffer (bool): Whether to buffer input streams.
running (bool): Flag to indicate if the streaming thread is running.
mode (str): Set to 'stream' indicating real-time capture.
imgs (List[List[np.ndarray]]): List of image frames for each stream.
fps (List[float]): List of FPS for each stream.
frames (List[int]): List of total frames for each stream.
threads (List[Thread]): List of threads for each stream.
shape (List[Tuple[int, int, int]]): List of shapes for each stream.
caps (List[cv2.VideoCapture]): List of cv2.VideoCapture objects for each stream.
bs (int): Batch size for processing.
Methods:
update: Read stream frames in daemon thread.
close: Close stream loader and release resources.
__iter__: Returns an iterator object for the class.
__next__: Returns source paths, transformed, and original images for processing.
__len__: Return the length of the sources object.
Examples:
>>> stream_loader = LoadStreams("rtsp://example.com/stream1.mp4")
>>> for sources, imgs, _ in stream_loader:
... # Process the images
... pass
>>> stream_loader.close()
Notes:
- The class uses threading to efficiently load frames from multiple streams simultaneously.
- It automatically handles YouTube links, converting them to the best available stream URL.
- The class implements a buffer system to manage frame storage and retrieval.
"""
def __init__(self, sources="file.streams", vid_stride=1, buffer=False):
"""Initialize stream loader for multiple video sources, supporting various stream types."""
# 设置 torch.backends.cudnn.benchmark 以加速固定大小的推理
torch.backends.cudnn.benchmark = True # faster for fixed-size inference
# 初始化布尔标志 buffer 和 running
self.buffer = buffer # buffer input streams
self.running = True # running flag for Thread
# 设置模式为 stream
self.mode = "stream"
# 初始化视频帧率步长 vid_stride
self.vid_stride = vid_stride # video frame-rate stride
# 如果 sources 是文件路径,读取文件内容并分割为多个源
# 否则,将 sources 转换为列表
sources = Path(sources).read_text().rsplit() if os.path.isfile(sources) else [sources]
n = len(sources)
# 初始化批量大小 bs
self.bs = n
# 初始化帧率、帧数、线程、视频捕获对象、图像列表和形状列表。
self.fps = [0] * n # frames per second
self.frames = [0] * n
self.threads = [None] * n
self.caps = [None] * n # video capture objects
self.imgs = [[] for _ in range(n)] # images
self.shape = [[] for _ in range(n)] # image shapes
# 清理输入源名称
self.sources = [ops.clean_str(x) for x in sources] # clean source names for later
# 处理每个输入源
# 遍历每个输入源
for i, s in enumerate(sources): # index, source
# Start thread to read frames from video stream
st = f"{i + 1}/{n}: {s}... "
# 如果输入源是 YouTube 视频,获取最佳流 URL
if urlparse(s).hostname in {"www.youtube.com", "youtube.com", "youtu.be"}: # if source is YouTube video
# YouTube format i.e. 'https://www.youtube.com/watch?v=Jsn8D3aC840' or 'https://youtu.be/Jsn8D3aC840'
s = get_best_youtube_url(s)
# 如果输入源是数字(如 0 表示本地摄像头),使用 eval 转换
s = eval(s) if s.isnumeric() else s # i.e. s = '0' local webcam
# 如果输入源是 0 且在 Colab 或 Kaggle 环境中,抛出 NotImplementedError。
if s == 0 and (IS_COLAB or IS_KAGGLE):
raise NotImplementedError(
"'source=0' webcam not supported in Colab and Kaggle notebooks. "
"Try running 'source=0' in a local environment."
)
# 使用 cv2.VideoCapture 打开视频流
self.caps[i] = cv2.VideoCapture(s) # store video capture object
if not self.caps[i].isOpened():
raise ConnectionError(f"{st}Failed to open {s}")
# 获取视频流的宽度、高度和帧率
w = int(self.caps[i].get(cv2.CAP_PROP_FRAME_WIDTH))
h = int(self.caps[i].get(cv2.CAP_PROP_FRAME_HEIGHT))
fps = self.caps[i].get(cv2.CAP_PROP_FPS) # warning: may return 0 or nan
self.frames[i] = max(int(self.caps[i].get(cv2.CAP_PROP_FRAME_COUNT)), 0) or float(
"inf"
) # infinite stream fallback
# 如果帧率无效,使用 30 FPS 作为默认值
self.fps[i] = max((fps if math.isfinite(fps) else 0) % 100, 0) or 30 # 30 FPS fallback
# 读取第一帧,确保视频流可以读取。
success, im = self.caps[i].read() # guarantee first frame
# 如果读取失败,抛出 ConnectionError
if not success or im is None:
raise ConnectionError(f"{st}Failed to read images from {s}")
# 将第一帧添加到图像列表中
self.imgs[i].append(im)
# 保存图像形状
self.shape[i] = im.shape
# 创建并启动线程,用于在后台读取视频流
self.threads[i] = Thread(target=self.update, args=([i, self.caps[i], s]), daemon=True)
# 记录成功信息
LOGGER.info(f"{st}Success ✅ ({self.frames[i]} frames of shape {w}x{h} at {self.fps[i]:.2f} FPS)")
self.threads[i].start()
LOGGER.info("") # newline
# 在后台线程中读取视频流的帧,并更新图像缓冲区
def update(self, i, cap, stream):
"""Read stream frames in daemon thread and update image buffer."""
# 初始化帧数 n 和帧数组 f
n, f = 0, self.frames[i] # frame number, frame array
# 如果线程正在运行、视频流已打开且帧数未达到总帧数,则继续读取帧
while self.running and cap.isOpened() and n < (f - 1):
# 限制缓冲区大小为最多 30 帧
if len(self.imgs[i]) < 30: # keep a <=30-image buffer
# 增加帧数 n
n += 1
# 使用 cap.grab() 和 cap.retrieve() 读取帧
cap.grab() # .read() = .grab() followed by .retrieve()
if n % self.vid_stride == 0:
success, im = cap.retrieve()
# 如果读取失败,记录警告信息并重新打开视频流
if not success:
im = np.zeros(self.shape[i], dtype=np.uint8)
LOGGER.warning("WARNING ⚠️ Video stream unresponsive, please check your IP camera connection.")
cap.open(stream) # re-open stream if signal was lost
# 如果启用缓冲区,将帧添加到缓冲区;否则,仅保留最新帧
if self.buffer:
self.imgs[i].append(im)
else:
self.imgs[i] = [im]
else:
# 如果缓冲区已满, 等待缓冲区清空
time.sleep(0.01) # wait until the buffer is empty
# 终止视频流加载器,停止线程并释放视频捕获资源
def close(self):
"""Terminates stream loader, stops threads, and releases video capture resources."""
# 设置 running 标志为 False,停止线程
self.running = False # stop flag for Thread
# 遍历线程列表,等待每个线程结束,设置超时时间为 5 秒
for thread in self.threads:
if thread.is_alive():
thread.join(timeout=5) # Add timeout
# 遍历视频捕获对象列表,释放每个对象
for cap in self.caps: # Iterate through the stored VideoCapture objects
try:
cap.release() # release video capture
# 如果释放失败,记录警告信息
except Exception as e:
LOGGER.warning(f"WARNING ⚠️ Could not release VideoCapture object: {e}")
# 关闭所有 OpenCV 窗口
cv2.destroyAllWindows()
# 返回一个迭代器对象,用于遍历 YOLO 图像流
def __iter__(self):
"""Iterates through YOLO image feed and re-opens unresponsive streams."""
# 初始化计数器 count
self.count = -1
return self
# 返回多个视频流的下一帧,用于处理
def __next__(self):
"""Returns the next batch of frames from multiple video streams for processing."""
# 增加计数器 count
self.count += 1
images = [] # 初始化图像列表
# 遍历每个视频流的图像缓冲区
for i, x in enumerate(self.imgs):
# Wait until a frame is available in each buffer
# 如果缓冲区为空,等待帧可用。
while not x:
# 如果线程已停止或按下 q 键,关闭流加载器并抛出 StopIteration
if not self.threads[i].is_alive() or cv2.waitKey(1) == ord("q"): # q to quit
self.close()
raise StopIteration
# 稍作等待,避免过快轮询
time.sleep(1 / min(self.fps))
x = self.imgs[i]
if not x:
LOGGER.warning(f"WARNING ⚠️ Waiting for stream {i}")
# Get and remove the first frame from imgs buffer
# 如果启用缓冲区,从缓冲区中移除并返回第一帧
if self.buffer:
images.append(x.pop(0))
# Get the last frame, and clear the rest from the imgs buffer
# 如果不启用缓冲区,返回最后一帧并清空缓冲区。
else:
images.append(x.pop(-1) if x else np.zeros(self.shape[i], dtype=np.uint8))
x.clear()
# 返回输入源路径、图像帧和空字符串列表
return self.sources, images, [""] * self.bs
def __len__(self):
"""Return the number of video streams in the LoadStreams object."""
# 返回批量大小 bs,即视频流的数量
return self.bs # 1E12 frames = 32 streams at 30 FPS for 30 years
# 用于捕获和处理屏幕截图,适用于实时屏幕捕获任务,例如与 YOLO 模型结合使用时的 yolo predict source=screen
class LoadScreenshots:
"""
Ultralytics screenshot dataloader for capturing and processing screen images.
This class manages the loading of screenshot images for processing with YOLO. It is suitable for use with
`yolo predict source=screen`.
Attributes:
source (str): The source input indicating which screen to capture.
screen (int): The screen number to capture.
left (int): The left coordinate for screen capture area.
top (int): The top coordinate for screen capture area.
width (int): The width of the screen capture area.
height (int): The height of the screen capture area.
mode (str): Set to 'stream' indicating real-time capture.
frame (int): Counter for captured frames.
sct (mss.mss): Screen capture object from `mss` library.
bs (int): Batch size, set to 1.
fps (int): Frames per second, set to 30.
monitor (Dict[str, int]): Monitor configuration details.
Methods:
__iter__: Returns an iterator object.
__next__: Captures the next screenshot and returns it.
Examples:
>>> loader = LoadScreenshots("0 100 100 640 480") # screen 0, top-left (100,100), 640x480
>>> for source, im, im0s, vid_cap, s in loader:
... print(f"Captured frame: {im.shape}")
"""
def __init__(self, source):
"""Initialize screenshot capture with specified screen and region parameters."""
# 检查是否安装了 mss 库,如果没有安装,则尝试安装。
# 导入 mss 库。
check_requirements("mss")
# mss 库(Multiple Screen Shots)是一个用于屏幕截图的高性能 Python 库,它提供了跨平台的屏幕捕获功能
# 主要功能
# 屏幕截图:可以捕获全屏或指定区域的屏幕截图。
# 跨平台支持:在 Windows、macOS 和 Linux 上都能高效运行。
# 高性能:使用 ctypes 模块直接调用操作系统的底层 API,确保截图操作快速且高效。
# 与 Python 生态系统集成:与 NumPy 和 OpenCV 等库集成良好,便于在图像处理和计算机视觉任务中使用。
# 简单易用:提供简洁的 API,易于上手和集成。
import mss # noqa
# 将输入参数 source 按空格分割
source, *params = source.split()
# 初始化屏幕编号和捕获区域的坐标和尺寸。
self.screen, left, top, width, height = 0, None, None, None, None # default to full screen 0
# 根据参数数量,设置屏幕编号和捕获区域的坐标和尺寸
if len(params) == 1:
self.screen = int(params[0])
elif len(params) == 4:
left, top, width, height = (int(x) for x in params)
elif len(params) == 5:
self.screen, left, top, width, height = (int(x) for x in params)
# 设置模式为 stream,表示实时捕获。
self.mode = "stream"
# 初始化帧计数器 frame
self.frame = 0
# 初始化屏幕捕获对象 sct
self.sct = mss.mss()
# 设置批量大小 bs 为 1
self.bs = 1
# 设置帧率为 30 FPS
self.fps = 30
# Parse monitor shape
# 获取指定屏幕的监视器配置
monitor = self.sct.monitors[self.screen]
# 设置捕获区域的坐标和尺寸,如果未指定,则使用默认值
self.top = monitor["top"] if top is None else (monitor["top"] + top)
self.left = monitor["left"] if left is None else (monitor["left"] + left)
self.width = width or monitor["width"]
self.height = height or monitor["height"]
self.monitor = {"left": self.left, "top": self.top, "width": self.width, "height": self.height}
# 返回一个迭代器对象,用于遍历屏幕截图
def __iter__(self):
"""Yields the next screenshot image from the specified screen or region for processing."""
return self
# 捕获并返回下一帧屏幕截图
def __next__(self):
"""Captures and returns the next screenshot as a numpy array using the mss library."""
# 使用 mss 库捕获指定区域的屏幕截图。 self.sct = mss.mss()
# 将捕获的图像从 BGRA 格式转换为 BGR 格式
im0 = np.asarray(self.sct.grab(self.monitor))[:, :, :3] # BGRA to BGR
# 生成日志信息,包含屏幕编号和捕获区域的坐标和尺寸
s = f"screen {self.screen} (LTWH): {self.left},{self.top},{self.width},{self.height}: "
# 增加帧计数器
self.frame += 1
# 返回屏幕编号、捕获的图像和日志信息
return [str(self.screen)], [im0], [s] # screen, img, string
# 用于加载和处理图像和视频数据,支持从单个文件、文件夹或文本文件中读取路径
class LoadImagesAndVideos:
"""
A class for loading and processing images and videos for YOLO object detection.
This class manages the loading and pre-processing of image and video data from various sources, including
single image files, video files, and lists of image and video paths.
Attributes:
files (List[str]): List of image and video file paths.
nf (int): Total number of files (images and videos).
video_flag (List[bool]): Flags indicating whether a file is a video (True) or an image (False).
mode (str): Current mode, 'image' or 'video'.
vid_stride (int): Stride for video frame-rate.
bs (int): Batch size.
cap (cv2.VideoCapture): Video capture object for OpenCV.
frame (int): Frame counter for video.
frames (int): Total number of frames in the video.
count (int): Counter for iteration, initialized at 0 during __iter__().
ni (int): Number of images.
Methods:
__init__: Initialize the LoadImagesAndVideos object.
__iter__: Returns an iterator object for VideoStream or ImageFolder.
__next__: Returns the next batch of images or video frames along with their paths and metadata.
_new_video: Creates a new video capture object for the given path.
__len__: Returns the number of batches in the object.
Examples:
>>> loader = LoadImagesAndVideos("path/to/data", batch=32, vid_stride=1)
>>> for paths, imgs, info in loader:
... # Process batch of images or video frames
... pass
Notes:
- Supports various image formats including HEIC.
- Handles both local files and directories.
- Can read from a text file containing paths to images and videos.
"""
def __init__(self, path, batch=1, vid_stride=1):
"""Initialize dataloader for images and videos, supporting various input formats."""
# 如果输入路径是 .txt 文件,读取文件内容并解析为路径列表。
# 保存 .txt 文件的父目录路径。
parent = None
if isinstance(path, str) and Path(path).suffix == ".txt": # *.txt file with img/vid/dir on each line
parent = Path(path).parent
path = Path(path).read_text().splitlines() # list of sources
files = []
# 遍历输入路径,支持通配符、文件夹和单个文件。
for p in sorted(path) if isinstance(path, (list, tuple)) else [path]:
a = str(Path(p).absolute()) # do not use .resolve() https://github.com/ultralytics/ultralytics/issues/2912
# 如果路径包含通配符,使用 glob 匹配文件
if "*" in a:
files.extend(sorted(glob.glob(a, recursive=True))) # glob
# 如果路径是文件夹,获取文件夹中的所有文件。
elif os.path.isdir(a):
files.extend(sorted(glob.glob(os.path.join(a, "*.*")))) # dir
# 如果路径是文件,直接添加到文件列表
elif os.path.isfile(a):
files.append(a) # files (absolute or relative to CWD)
# 如果路径是 .txt 文件中的相对路径,解析为绝对路径
elif parent and (parent / p).is_file():
files.append(str((parent / p).absolute())) # files (relative to *.txt file parent)
else:
raise FileNotFoundError(f"{p} does not exist")
# Define files as images or videos
# 将文件分类为图像或视频
images, videos = [], []
# 支持的图像格式存储在 IMG_FORMATS 中,视频格式存储在 VID_FORMATS 中
for f in files:
suffix = f.split(".")[-1].lower() # Get file extension without the dot and lowercase
if suffix in IMG_FORMATS:
images.append(f)
elif suffix in VID_FORMATS:
videos.append(f)
ni, nv = len(images), len(videos)
# 初始化文件列表、文件数量、图像数量、视频标志、模式、帧率步长和批量大小
self.files = images + videos
self.nf = ni + nv # number of files
self.ni = ni # number of images
self.video_flag = [False] * ni + [True] * nv
self.mode = "video" if ni == 0 else "image" # default to video if no images
self.vid_stride = vid_stride # video frame-rate stride
self.bs = batch
# 如果有视频文件,初始化视频捕获对象
if any(videos):
self._new_video(videos[0]) # new video
else:
self.cap = None
# 如果没有找到任何文件,抛出 FileNotFoundError
if self.nf == 0:
raise FileNotFoundError(f"No images or videos found in {p}. {FORMATS_HELP_MSG}")
# 返回一个迭代器对象,用于遍历图像和视频文件
def __iter__(self):
"""Iterates through image/video files, yielding source paths, images, and metadata."""
self.count = 0
# 返回当前对象作为迭代器
return self
# 返回下一帧图像或视频帧,以及它们的路径和元数据
def __next__(self):
"""Returns the next batch of images or video frames with their paths and metadata."""
# 初始化返回值:路径、图像和元数据列表。
paths, imgs, info = [], [], []
# 循环直到批量大小满足要求
while len(imgs) < self.bs:
# 如果文件列表结束,返回剩余的批量或抛出 StopIteration
if self.count >= self.nf: # end of file list
if imgs:
return paths, imgs, info # return last partial batch
else:
raise StopIteration
# 获取当前文件路径
path = self.files[self.count]
# 处理视频文件
if self.video_flag[self.count]:
# 如果当前文件是视频,检查视频捕获对象是否打开
self.mode = "video"
if not self.cap or not self.cap.isOpened():
self._new_video(path) # 如果未打开,调用 _new_video 方法初始化视频捕获对象
# 按帧率步长读取视频帧。
success = False
for _ in range(self.vid_stride):
success = self.cap.grab()
if not success:
break # end of video or failure
# 如果帧读取成功,增加帧计数器。
if success:
success, im0 = self.cap.retrieve()
# 将路径、图像和元数据添加到返回列
if success:
self.frame += 1
paths.append(path)
imgs.append(im0)
info.append(f"video {self.count + 1}/{self.nf} (frame {self.frame}/{self.frames}) {path}: ")
# 如果视频结束,释放视频捕获对象并移动到下一个文件
if self.frame == self.frames: # end of video
self.count += 1
self.cap.release()
# 如果帧读取失败,释放视频捕获对象并移动到下一个文件
else:
# Move to the next file if the current video ended or failed to open
self.count += 1
if self.cap:
self.cap.release()
if self.count < self.nf:
self._new_video(self.files[self.count])
else:
# Handle image files (including HEIC)
self.mode = "image"
# 如果当前文件是图像,检查是否是 HEIC 格式
# 如果是 HEIC 格式,使用 pillow-heif 加载图像
if path.split(".")[-1].lower() == "heic":
# Load HEIC image using Pillow with pillow-heif
check_requirements("pillow-heif")
from pillow_heif import register_heif_opener
register_heif_opener() # Register HEIF opener with Pillow
with Image.open(path) as img:
im0 = cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2BGR) # convert image to BGR nparray
else:
im0 = imread(path) # BGR
# 如果图像加载失败,记录警告信息
if im0 is None:
LOGGER.warning(f"WARNING ⚠️ Image Read Error {path}")
# 如果图像加载成功,将路径、图像和元数据添加到返回列表
else:
paths.append(path)
imgs.append(im0)
info.append(f"image {self.count + 1}/{self.nf} {path}: ")
self.count += 1 # move to the next file
if self.count >= self.ni: # end of image list
break
return paths, imgs, info
# 为给定路径创建一个新的视频捕获对象,并初始化视频相关属性。
def _new_video(self, path):
"""Creates a new video capture object for the given path and initializes video-related attributes."""
# 初始化帧计数器。
self.frame = 0
# 使用 OpenCV 创建视频捕获对象
self.cap = cv2.VideoCapture(path)
# 获取视频的帧率
self.fps = int(self.cap.get(cv2.CAP_PROP_FPS))
# 如果视频未打开,抛出 FileNotFoundError
if not self.cap.isOpened():
raise FileNotFoundError(f"Failed to open video {path}")
# 计算总帧数,考虑帧率步长
self.frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT) / self.vid_stride)
# 返回数据集中的文件数量(图像和视频)
def __len__(self):
"""Returns the number of files (images and videos) in the dataset."""
return math.ceil(self.nf / self.bs) # number of batches
# 用于从 PIL 和 NumPy 数组中加载图像数据,支持批量处理。
class LoadPilAndNumpy:
"""
Load images from PIL and Numpy arrays for batch processing.
This class manages loading and pre-processing of image data from both PIL and Numpy formats. It performs basic
validation and format conversion to ensure that the images are in the required format for downstream processing.
Attributes:
paths (List[str]): List of image paths or autogenerated filenames.
im0 (List[np.ndarray]): List of images stored as Numpy arrays.
mode (str): Type of data being processed, set to 'image'.
bs (int): Batch size, equivalent to the length of `im0`.
Methods:
_single_check: Validate and format a single image to a Numpy array.
Examples:
>>> from PIL import Image
>>> import numpy as np
>>> pil_img = Image.new("RGB", (100, 100))
>>> np_img = np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8)
>>> loader = LoadPilAndNumpy([pil_img, np_img])
>>> paths, images, _ = next(iter(loader))
>>> print(f"Loaded {len(images)} images")
Loaded 2 images
"""
def __init__(self, im0):
"""Initializes a loader for PIL and Numpy images, converting inputs to a standardized format."""
# 如果输入不是列表,将其转换为列表
if not isinstance(im0, list):
im0 = [im0]
# use `image{i}.jpg` when Image.filename returns an empty path.
# 初始化路径列表,使用图像的文件名或自动生成的文件名
self.paths = [getattr(im, "filename", "") or f"image{i}.jpg" for i, im in enumerate(im0)]
# 初始化图像列表,调用 _single_check 方法验证和格式化每个图像
self.im0 = [self._single_check(im) for im in im0]
# 设置模式为 image,批量大小为图像列表的长度
self.mode = "image"
self.bs = len(self.im0)
# 验证和格式化单个图像为 NumPy 数组,确保 RGB 顺序和连续内存
@staticmethod
def _single_check(im):
"""Validate and format an image to numpy array, ensuring RGB order and contiguous memory."""
# 确保输入是 PIL 图像或 NumPy 数组
assert isinstance(im, (Image.Image, np.ndarray)), f"Expected PIL/np.ndarray image type, but got {type(im)}"
# 如果图像是 PIL 格式,确保其模式为 RGB
if isinstance(im, Image.Image):
# 将 PIL 图像转换为 NumPy 数组,并调整通道顺序为 BGR
if im.mode != "RGB":
im = im.convert("RGB")
im = np.asarray(im)[:, :, ::-1]
im = np.ascontiguousarray(im) # 确保内存连续
return im
# 返回加载的图像数量
def __len__(self):
"""Returns the length of the 'im0' attribute, representing the number of loaded images."""
return len(self.im0)
# 返回下一帧图像、路径和元数据
def __next__(self):
"""Returns the next batch of images, paths, and metadata for processing."""
if self.count == 1: # loop only once as it's batch inference
raise StopIteration
self.count += 1
return self.paths, self.im0, [""] * self.bs
# 返回一个迭代器对象,用于遍历图像
def __iter__(self):
"""Iterates through PIL/numpy images, yielding paths, raw images, and metadata for processing."""
self.count = 0
return self
# 用于加载和处理 PyTorch 张量数据,准备用于目标检测任务。
class LoadTensor:
"""
A class for loading and processing tensor data for object detection tasks.
This class handles the loading and pre-processing of image data from PyTorch tensors, preparing them for
further processing in object detection pipelines.
Attributes:
im0 (torch.Tensor): The input tensor containing the image(s) with shape (B, C, H, W).
bs (int): Batch size, inferred from the shape of `im0`.
mode (str): Current processing mode, set to 'image'.
paths (List[str]): List of image paths or auto-generated filenames.
Methods:
_single_check: Validates and formats an input tensor.
Examples:
>>> import torch
>>> tensor = torch.rand(1, 3, 640, 640)
>>> loader = LoadTensor(tensor)
>>> paths, images, info = next(iter(loader))
>>> print(f"Processed {len(images)} images")
"""
def __init__(self, im0) -> None:
"""Initialize LoadTensor object for processing torch.Tensor image data."""
# 调用 _single_check 方法验证和格式化输入张量。
self.im0 = self._single_check(im0)
self.bs = self.im0.shape[0] # 从张量的形状中获取批量大小。
self.mode = "image" # 设置当前处理模式为 image
# 生成路径列表,使用张量的 filename 属性或自动生成的文件名
self.paths = [getattr(im, "filename", f"image{i}.jpg") for i, im in enumerate(im0)]
# 验证和格式化单个图像张量,确保其形状和归一化正确
@staticmethod
def _single_check(im, stride=32):
"""Validates and formats a single image tensor, ensuring correct shape and normalization."""
# 定义警告信息,提示输入张量应为 BCHW 格式且尺寸应能被步长整除。
s = (
f"WARNING ⚠️ torch.Tensor inputs should be BCHW i.e. shape(1, 3, 640, 640) "
f"divisible by stride {stride}. Input shape{tuple(im.shape)} is incompatible."
)
# 如果张量维度不是 4,检查是否为 3 维。如果是 3 维,添加批量维度
if len(im.shape) != 4:
if len(im.shape) != 3:
raise ValueError(s)
LOGGER.warning(s)
im = im.unsqueeze(0)
# 检查张量的高度和宽度是否能被步长整除
if im.shape[2] % stride or im.shape[3] % stride:
raise ValueError(s)
# 如果张量的最大值大于 1.0,提示警告并将其归一化到 [0, 1] 范围
if im.max() > 1.0 + torch.finfo(im.dtype).eps: # torch.float32 eps is 1.2e-07
LOGGER.warning(
f"WARNING ⚠️ torch.Tensor inputs should be normalized 0.0-1.0 but max value is {im.max()}. "
f"Dividing input by 255."
)
im = im.float() / 255.0
# 返回格式化后的张量
return im
# 返回一个迭代器对象,用于遍历张量图像数据。
def __iter__(self):
"""Yields an iterator object for iterating through tensor image data."""
self.count = 0
return self
# 返回下一帧张量图像和元数据
def __next__(self):
"""Yields the next batch of tensor images and metadata for processing."""
if self.count == 1:
raise StopIteration
self.count += 1
return self.paths, self.im0, [""] * self.bs
# 返回张量输入的批量大小
def __len__(self):
"""Returns the batch size of the tensor input."""
return self.bs
# 将输入源列表转换为 PIL 图像或 NumPy 数组列表,用于 Ultralytics 预测。
def autocast_list(source):
"""Merges a list of sources into a list of numpy arrays or PIL images for Ultralytics prediction."""
files = []
# 遍历输入源列表
for im in source:
# 如果输入是文件路径或 URI,使用 requests 获取内容并打开为 PIL 图像
if isinstance(im, (str, Path)): # filename or uri
files.append(Image.open(requests.get(im, stream=True).raw if str(im).startswith("http") else im))
# 如果输入是 PIL 图像或 NumPy 数组,直接添加到文件列表
elif isinstance(im, (Image.Image, np.ndarray)): # PIL or np Image
files.append(im)
# 抛出不支持的类型错误
else:
raise TypeError(
f"type {type(im).__name__} is not a supported Ultralytics prediction source type. \n"
f"See https://docs.ultralytics.com/modes/predict for supported source types."
)
return files
# 从给定的 YouTube 视频中获取最佳质量的 MP4 视频流 URL
def get_best_youtube_url(url, method="pytube"):
"""
Retrieves the URL of the best quality MP4 video stream from a given YouTube video.
Args:
url (str): The URL of the YouTube video.
method (str): The method to use for extracting video info. Options are "pytube", "pafy", and "yt-dlp".
Defaults to "pytube".
Returns:
(str | None): The URL of the best quality MP4 video stream, or None if no suitable stream is found.
Examples:
>>> url = "https://www.youtube.com/watch?v=dQw4w9WgXcQ"
>>> best_url = get_best_youtube_url(url)
>>> print(best_url)
https://rr4---sn-q4flrnek.googlevideo.com/videoplayback?expire=...
Notes:
- Requires additional libraries based on the chosen method: pytubefix, pafy, or yt-dlp.
- The function prioritizes streams with at least 1080p resolution when available.
- For the "yt-dlp" method, it looks for formats with video codec, no audio, and *.mp4 extension.
"""
# 处理 pytube 方法
if method == "pytube":
# Switched from pytube to pytubefix to resolve https://github.com/pytube/pytube/issues/1954
# 确保安装了 pytubefix 库,版本至少为 6.5.2。
check_requirements("pytubefix>=6.5.2")
from pytubefix import YouTube
# 使用 YouTube 类获取视频的所有 MP4 格式、仅视频的流
streams = YouTube(url).streams.filter(file_extension="mp4", only_video=True)
# 按分辨率从高到低排序
streams = sorted(streams, key=lambda s: s.resolution, reverse=True) # sort streams by resolution
# 遍历流,选择分辨率至少为 1080p 的流。
for stream in streams:
if stream.resolution and int(stream.resolution[:-1]) >= 1080: # check if resolution is at least 1080p
# 返回该流的 URL
return stream.url
# 处理 pafy 方法
elif method == "pafy":
# 确保安装了 pafy 和 youtube_dl 库,youtube_dl 版本为 2020.12.2
check_requirements(("pafy", "youtube_dl==2020.12.2"))
import pafy # noqa
# 使用 pafy 获取最佳 MP4 格式的视频流
# 返回该流的 URL
return pafy.new(url).getbestvideo(preftype="mp4").url
# 处理 yt-dlp 方法
# 确保安装了 yt-dlp 库
elif method == "yt-dlp":
check_requirements("yt-dlp")
import yt_dlp
# 提取视频信息
with yt_dlp.YoutubeDL({"quiet": True}) as ydl:
info_dict = ydl.extract_info(url, download=False) # extract info
# 选择最佳流
# 遍历流,从后向前查找最佳流(通常最佳流在最后)
for f in reversed(info_dict.get("formats", [])): # reversed because best is usually last
# Find a format with video codec, no audio, *.mp4 extension at least 1920x1080 size
# 选择分辨率至少为 1920x1080 的流,且视频编码存在、音频编码不存在、格式为 MP4
good_size = (f.get("width") or 0) >= 1920 or (f.get("height") or 0) >= 1080
if good_size and f["vcodec"] != "none" and f["acodec"] == "none" and f["ext"] == "mp4":
return f.get("url")
# LOADERS 常量的作用是将这些加载器类集中管理,方便在代码中根据需要选择合适的加载器。
# 例如,可以在一个函数中根据输入数据的类型选择合适的加载器来加载数据。
# 它是一个元组,包含四个类:LoadStreams、LoadPilAndNumpy、LoadImagesAndVideos 和 LoadScreenshots。
# 这些类分别用于加载不同类型的数据源,例如视频流、PIL 和 NumPy 图像、图像和视频文件,以及屏幕截图
# Define constants
LOADERS = (LoadStreams, LoadPilAndNumpy, LoadImagesAndVideos, LoadScreenshots)
# 示例用法
# 假设你有一个函数 load_data,它根据输入数据的类型选择合适的加载器:
# def load_data(source):
# for loader in LOADERS:
# try:
# return loader(source)
# except Exception as e:
# print(f"Failed to load with {loader.__name__}: {e}")
# raise ValueError("No suitable loader found for the given source.")
# 在这个函数中,LOADERS 被用来遍历所有可能的加载器,直到找到一个能够成功加载数据的加载器。
# 以下是关于 yt-dlp、pafy 和 pytubefix 这三个库的功能说明:
# 1. yt-dlp
# yt-dlp 是一个功能强大的命令行工具,用于下载 YouTube 视频及其相关信息。它支持多种格式选择、过滤和下载选项,能够处理复杂的下载需求。
# 主要功能:
# 格式选择:支持下载最佳视频、音频流,或特定格式(如 MP4、m4a)的流。
# 过滤和排序:可以根据分辨率、大小、编解码器等条件过滤和排序视频流。
# 多流下载:支持下载多个音频流并合并到一个文件中。
# 字幕下载:可以下载视频的字幕文件(如 SRT 格式)。
# 高级过滤:支持复杂的过滤条件,如按协议、帧率、文件大小等。
# 2. pafy
# pafy 是一个用于处理 YouTube 视频的 Python 库,专注于获取视频的元数据和下载功能。
# 主要功能:
# 元数据获取:可以获取视频的标题、描述、观看次数、评分、缩略图等信息。
# 视频下载:支持下载视频和音频流,可以指定分辨率、格式等。
# 字幕支持:可以下载视频的字幕文件。
# 简单易用:提供了简洁的 API 接口,便于集成到其他项目中。
# 3. pytubefix
# pytubefix 是一个轻量级的 Python 库,用于下载 YouTube 视频,它是 pytube 的一个修复版本,解决了原版 pytube 中的一些问题。
# 主要功能:
# 视频下载:支持下载 YouTube 视频,包括最高分辨率的视频流。
# 音频下载:可以下载音频流(如 m4a 格式)。
# 字幕下载:支持下载字幕文件并保存为 SRT 格式。
# 无依赖:不依赖第三方库,保持轻量级。
# 回调支持:支持下载进度回调和完成回调。
ultralytics/data/split_dota.py 代码阅读
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
import itertools
from glob import glob
from math import ceil
from pathlib import Path
import cv2
import numpy as np
from PIL import Image
from ultralytics.data.utils import exif_size, img2label_paths
from ultralytics.utils import TQDM
from ultralytics.utils.checks import check_requirements
# 用于计算多边形和边界框之间的交并比
def bbox_iof(polygon1, bbox2, eps=1e-6):
"""
Calculate Intersection over Foreground (IoF) between polygons and bounding boxes.
Args:
polygon1 (np.ndarray): Polygon coordinates, shape (n, 8).
bbox2 (np.ndarray): Bounding boxes, shape (n, 4).
eps (float, optional): Small value to prevent division by zero. Defaults to 1e-6.
Returns:
(np.ndarray): IoF scores, shape (n, 1) or (n, m) if bbox2 is (m, 4).
Note:
Polygon format: [x1, y1, x2, y2, x3, y3, x4, y4].
Bounding box format: [x_min, y_min, x_max, y_max].
"""
# 确保安装了 shapely 库,这是一个用于几何操作的 Python 库。
# 导入 shapely.geometry.Polygon,用于处理多边形。
check_requirements("shapely")
from shapely.geometry import Polygon
# 将多边形坐标从 (n, 8) 重塑为 (n, 4, 2),表示每个多边形的四个顶点。
polygon1 = polygon1.reshape(-1, 4, 2)
# 计算多边形的左上角和右下角点
lt_point = np.min(polygon1, axis=-2) # left-top
rb_point = np.max(polygon1, axis=-2) # right-bottom
# 将这些点组合成边界框,格式为 [x_min, y_min, x_max, y_max]
bbox1 = np.concatenate([lt_point, rb_point], axis=-1)
# 计算两个边界框的交集区域
# 使用 np.maximum 和 np.minimum 找到交集的左上角和右下角点
lt = np.maximum(bbox1[:, None, :2], bbox2[..., :2])
rb = np.minimum(bbox1[:, None, 2:], bbox2[..., 2:])
# 计算交集区域的宽度和高度
wh = np.clip(rb - lt, 0, np.inf)
# 计算交集区域的面积
h_overlaps = wh[..., 0] * wh[..., 1]
# 从边界框生成多边形的四个顶点
left, top, right, bottom = (bbox2[..., i] for i in range(4))
polygon2 = np.stack([left, top, right, top, right, bottom, left, bottom], axis=-1).reshape(-1, 4, 2)
# 将 NumPy 数组转换为 Shapely 多边形对象
sg_polys1 = [Polygon(p) for p in polygon1]
sg_polys2 = [Polygon(p) for p in polygon2]
# 初始化交集面积数组。
overlaps = np.zeros(h_overlaps.shape)
# 遍历所有可能的交集区域,计算实际交集面积
for p in zip(*np.nonzero(h_overlaps)):
overlaps[p] = sg_polys1[p[0]].intersection(sg_polys2[p[-1]]).area
# 计算并集面积, 计算每个多边形的面积
unions = np.array([p.area for p in sg_polys1], dtype=np.float32)
# 确保并集面积不为零,避免除零错误
unions = unions[..., None]
unions = np.clip(unions, eps, np.inf)
# 计算交并比(IoF)
outputs = overlaps / unions
# 如果结果是一维的,将其扩展为二维数组。
if outputs.ndim == 1:
outputs = outputs[..., None]
return outputs
# 加载 DOTA 数据集,支持 train 和 val 分割
def load_yolo_dota(data_root, split="train"):
"""
Load DOTA dataset.
Args:
data_root (str): Data root.
split (str): The split data set, could be `train` or `val`.
Notes:
The directory structure assumed for the DOTA dataset:
- data_root
- images
- train
- val
- labels
- train
- val
"""
# 确保 split 参数是 train 或 val
assert split in {"train", "val"}, f"Split must be 'train' or 'val', not {split}."
# 构造图像目录路径
im_dir = Path(data_root) / "images" / split
# 确保路径存在。
assert im_dir.exists(), f"Can't find {im_dir}, please check your data root."
# 获取图像文件路径
im_files = glob(str(Path(data_root) / "images" / split / "*"))
# 使用 img2label_paths 函数将图像路径转换为标签路径。
lb_files = img2label_paths(im_files)
# 遍历每个图像文件和对应的标签文件。
annos = []
for im_file, lb_file in zip(im_files, lb_files):
# 使用 exif_size 获取图像的宽度和高度。
w, h = exif_size(Image.open(im_file))
# 读取标签文件内容,解析为 NumPy 数组。
with open(lb_file) as f:
lb = [x.split() for x in f.read().strip().splitlines() if len(x)]
lb = np.array(lb, dtype=np.float32)
# 将标注信息存储为字典,包括原始尺寸、标签和文件路径。
annos.append(dict(ori_size=(h, w), label=lb, filepath=im_file))
# 回标注信息列表
return annos
# 根据给定的图像大小、裁剪尺寸和间隔,计算窗口的坐标
# im_size:图像的原始大小,以元组形式表示(高度,宽度)。
# crop_sizes:窗口的裁剪尺寸,默认为 1024。
# gaps:窗口之间的间隔,默认为 200。
# im_rate_thr:窗口面积与图像面积的阈值,默认为 0.6。
# eps:用于数学运算的小值,默认为 0.01。
def get_windows(im_size, crop_sizes=(1024,), gaps=(200,), im_rate_thr=0.6, eps=0.01):
"""
Get the coordinates of windows.
Args:
im_size (tuple): Original image size, (h, w).
crop_sizes (List(int)): Crop size of windows.
gaps (List(int)): Gap between crops.
im_rate_thr (float): Threshold of windows areas divided by image ares.
eps (float): Epsilon value for math operations.
"""
# 将 im_size 解包为高度 h 和宽度 w。
h, w = im_size
# 初始化一个空列表 windows,用于存储窗口坐标。
windows = []
# 遍历每个裁剪尺寸和间隔
for crop_size, gap in zip(crop_sizes, gaps):
# 确保裁剪尺寸大于间隔,若不满足条件则抛出异常。
assert crop_size > gap, f"invalid crop_size gap pair [{crop_size} {gap}]"
# 计算步长 step,即裁剪尺寸减去间隔
step = crop_size - gap
# 计算水平窗口坐标
# 计算在给定宽度下可以放置的窗口数量 xn
xn = 1 if w <= crop_size else ceil((w - crop_size) / step + 1)
# 生成窗口的起始 x 坐标列表 xs
xs = [step * i for i in range(xn)]
# 如果最后一个窗口超出图像宽度,调整最后一个窗口的起始坐标
if len(xs) > 1 and xs[-1] + crop_size > w:
xs[-1] = w - crop_size
# 计算在给定高度下可以放置的窗口数量 yn
yn = 1 if h <= crop_size else ceil((h - crop_size) / step + 1)
# 生成窗口的起始 y 坐标列表 ys
ys = [step * i for i in range(yn)]
# 如果最后一个窗口超出图像高度,调整最后一个窗口的起始坐标
if len(ys) > 1 and ys[-1] + crop_size > h:
ys[-1] = h - crop_size
# 使用 itertools.product 生成所有可能的窗口起始坐标组合
start = np.array(list(itertools.product(xs, ys)), dtype=np.int64)
# 计算窗口的结束坐标 stop
stop = start + crop_size
# 将起始和结束坐标拼接并添加到 windows 列表中
windows.append(np.concatenate([start, stop], axis=1))
windows = np.concatenate(windows, axis=0)
# 复制窗口坐标到 im_in_wins
im_in_wins = windows.copy()
# 将窗口的坐标限制在图像的边界内
im_in_wins[:, 0::2] = np.clip(im_in_wins[:, 0::2], 0, w)
im_in_wins[:, 1::2] = np.clip(im_in_wins[:, 1::2], 0, h)
# 计算窗口和图像的面积
im_areas = (im_in_wins[:, 2] - im_in_wins[:, 0]) * (im_in_wins[:, 3] - im_in_wins[:, 1])
win_areas = (windows[:, 2] - windows[:, 0]) * (windows[:, 3] - windows[:, 1])
# 计算图像面积与窗口面积的比例 im_rates。
im_rates = im_areas / win_areas
# 如果没有窗口的面积比例超过阈值 im_rate_thr,则将接近最大比例的窗口比例设置为 1。
if not (im_rates > im_rate_thr).any():
max_rate = im_rates.max()
im_rates[abs(im_rates - max_rate) < eps] = 1
# 返回面积比例超过阈值的窗口坐标。
return windows[im_rates > im_rate_thr]
# 为每个窗口获取对应的对象信息。
def get_window_obj(anno, windows, iof_thr=0.7):
"""Get objects for each window."""
# 从标注信息中获取图像的原始高度和宽度。
h, w = anno["ori_size"]
label = anno["label"]
# 如果标签存在,将边界框的坐标从相对坐标转换为绝对坐标(乘以图像的宽度和高度)。
if len(label):
label[:, 1::2] *= w
label[:, 2::2] *= h
# 调用 bbox_iof 函数计算每个边界框与窗口的交集比(IoF)。
iofs = bbox_iof(label[:, 1:], windows)
# Unnormalized and misaligned coordinates
# 对于每个窗口,返回交集比大于等于阈值 iof_thr 的标签
return [(label[iofs[:, i] >= iof_thr]) for i in range(len(windows))] # window_anns
else:
# 如果没有标签,返回一个包含空数组的列表,表示每个窗口没有对象
return [np.zeros((0, 9), dtype=np.float32) for _ in range(len(windows))] # window_anns
# 裁剪图像并保存新的标签
def crop_and_save(anno, windows, window_objs, im_dir, lb_dir, allow_background_images=True):
"""
Crop images and save new labels.
Args:
anno (dict): Annotation dict, including `filepath`, `label`, `ori_size` as its keys.
windows (list): A list of windows coordinates.
window_objs (list): A list of labels inside each window.
im_dir (str): The output directory path of images.
lb_dir (str): The output directory path of labels.
allow_background_images (bool): Whether to include background images without labels.
Notes:
The directory structure assumed for the DOTA dataset:
- data_root
- images
- train
- val
- labels
- train
- val
"""
# 使用 OpenCV 读取指定路径的图像文件
im = cv2.imread(anno["filepath"])
# 从图像文件路径中提取文件名(不带扩展名)
name = Path(anno["filepath"]).stem
# 遍历每个窗口,获取窗口的起始和结束坐标。
for i, window in enumerate(windows):
x_start, y_start, x_stop, y_stop = window.tolist()
# 根据窗口的坐标生成新的文件名
new_name = f"{name}__{x_stop - x_start}__{x_start}___{y_start}"
# 使用坐标裁剪图像,得到窗口对应的图像部分
patch_im = im[y_start:y_stop, x_start:x_stop]
# 获取裁剪后图像的高度和宽度
ph, pw = patch_im.shape[:2]
# 获取当前窗口的标签信息。
label = window_objs[i]
# 如果标签存在或允许背景图像,则保存裁剪后的图像。
if len(label) or allow_background_images:
cv2.imwrite(str(Path(im_dir) / f"{new_name}.jpg"), patch_im)
# 如果标签存在,调整标签的坐标:
if len(label):
# 将坐标从窗口坐标系转换回相对坐标系
label[:, 1::2] -= x_start
label[:, 2::2] -= y_start
# 将坐标归一化到 [0, 1] 范围
label[:, 1::2] /= pw
label[:, 2::2] /= ph
# 打开标签文件并写入标签信息
with open(Path(lb_dir) / f"{new_name}.txt", "w") as f:
# 每个标签的格式为:class_id x_center y_center width height。
for lb in label:
formatted_coords = [f"{coord:.6g}" for coord in lb[1:]]
f.write(f"{int(lb[0])} {' '.join(formatted_coords)}\n")
# 将图像和标签数据集分割为指定的训练集或验证集,并保存到指定的目录结构中
# data_root:数据集根目录,包含图像和标签。
# save_dir:保存分割后图像和标签的目录。
# split:指定分割类型(train 或 val),默认为 train。
# crop_sizes:窗口的裁剪尺寸,默认为 1024。
# gaps:窗口之间的间隔,默认为 200。
def split_images_and_labels(data_root, save_dir, split="train", crop_sizes=(1024,), gaps=(200,)):
"""
Split both images and labels.
Notes:
The directory structure assumed for the DOTA dataset:
- data_root
- images
- split
- labels
- split
and the output directory structure is:
- save_dir
- images
- split
- labels
- split
"""
# 构造图像和标签的保存目录路径。
# 使用 mkdir 创建目录,确保父目录存在。
im_dir = Path(save_dir) / "images" / split
im_dir.mkdir(parents=True, exist_ok=True)
lb_dir = Path(save_dir) / "labels" / split
lb_dir.mkdir(parents=True, exist_ok=True)
# 调用 load_yolo_dota 函数加载指定分割的 DOTA 数据集的标注信息。
annos = load_yolo_dota(data_root, split=split)
# 遍历每个标注信息
# 使用 TQDM 显示进度条,遍历每个标注信息。
# 调用 get_windows 函数计算窗口坐标。
# 调用 get_window_obj 函数获取每个窗口的对象信息。
# 调用 crop_and_save 函数裁剪图像并保存新的标签。
for anno in TQDM(annos, total=len(annos), desc=split):
windows = get_windows(anno["ori_size"], crop_sizes, gaps)
window_objs = get_window_obj(anno, windows)
crop_and_save(anno, windows, window_objs, str(im_dir), str(lb_dir))
# 将 DOTA 数据集分割为训练集和验证集,并保存到指定的目录结构中。
# data_root:数据集根目录,包含图像和标签。
# save_dir:保存分割后图像和标签的目录。
# crop_size:裁剪窗口的大小,默认为 1024。
# gap:窗口之间的间隔,默认为 200。
# rates:训练集和验证集的比例,默认为 (1.0,)。
def split_trainval(data_root, save_dir, crop_size=1024, gap=200, rates=(1.0,)):
"""
Split train and val set of DOTA.
Notes:
The directory structure assumed for the DOTA dataset:
- data_root
- images
- train
- val
- labels
- train
- val
and the output directory structure is:
- save_dir
- images
- train
- val
- labels
- train
- val
"""
# 根据 rates 计算裁剪尺寸和间隔
crop_sizes, gaps = [], []
for r in rates:
crop_sizes.append(int(crop_size / r))
gaps.append(int(gap / r))
# 遍历训练集和验证集
# 对于每个分割(train 和 val),调用 split_images_and_labels 函数进行分割和保存。
for split in ["train", "val"]:
split_images_and_labels(data_root, save_dir, split, crop_sizes, gaps)
# 将 DOTA 数据集的测试集分割并保存到指定的目录结构中,测试集不包含标签。
def split_test(data_root, save_dir, crop_size=1024, gap=200, rates=(1.0,)):
"""
Split test set of DOTA, labels are not included within this set.
Notes:
The directory structure assumed for the DOTA dataset:
- data_root
- images
- test
and the output directory structure is:
- save_dir
- images
- test
"""
# 根据 rates 计算裁剪尺寸和间隔
crop_sizes, gaps = [], []
for r in rates:
crop_sizes.append(int(crop_size / r))
gaps.append(int(gap / r))
# 定义测试集的保存目录并创建。
save_dir = Path(save_dir) / "images" / "test"
save_dir.mkdir(parents=True, exist_ok=True)
# 获取测试集图像的路径。
im_dir = Path(data_root) / "images" / "test"
# 确保路径存在。
assert im_dir.exists(), f"Can't find {im_dir}, please check your data root."
im_files = glob(str(im_dir / "*"))
# 遍历测试集图像,获取每个图像的宽度和高度。
for im_file in TQDM(im_files, total=len(im_files), desc="test"):
# 计算窗口坐标
w, h = exif_size(Image.open(im_file))
windows = get_windows((h, w), crop_sizes=crop_sizes, gaps=gaps)
im = cv2.imread(im_file)
name = Path(im_file).stem
# 裁剪图像并保存到指定目录
for window in windows:
x_start, y_start, x_stop, y_stop = window.tolist()
new_name = f"{name}__{x_stop - x_start}__{x_start}___{y_start}"
patch_im = im[y_start:y_stop, x_start:x_stop]
cv2.imwrite(str(save_dir / f"{new_name}.jpg"), patch_im)
if __name__ == "__main__":
split_trainval(data_root="DOTAv2", save_dir="DOTAv2-split")
split_test(data_root="DOTAv2", save_dir="DOTAv2-split")
ultralytics/data/utils.py 代码阅读
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
import hashlib
import json
import os
import random
import subprocess
import time
import zipfile
from multiprocessing.pool import ThreadPool
from pathlib import Path
from tarfile import is_tarfile
import cv2
import numpy as np
from PIL import Image, ImageOps
from ultralytics.nn.autobackend import check_class_names
from ultralytics.utils import (
DATASETS_DIR,
LOGGER,
NUM_THREADS,
ROOT,
SETTINGS_FILE,
TQDM,
clean_url,
colorstr,
emojis,
is_dir_writeable,
yaml_load,
yaml_save,
)
from ultralytics.utils.checks import check_file, check_font, is_ascii
from ultralytics.utils.downloads import download, safe_download, unzip_file
from ultralytics.utils.ops import segments2boxes
HELP_URL = "See https://docs.ultralytics.com/datasets for dataset formatting guidance."
IMG_FORMATS = {"bmp", "dng", "jpeg", "jpg", "mpo", "png", "tif", "tiff", "webp", "pfm", "heic"} # image suffixes
VID_FORMATS = {"asf", "avi", "gif", "m4v", "mkv", "mov", "mp4", "mpeg", "mpg", "ts", "wmv", "webm"} # video suffixes
PIN_MEMORY = str(os.getenv("PIN_MEMORY", True)).lower() == "true" # global pin_memory for dataloaders
FORMATS_HELP_MSG = f"Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}"
# 根据图像路径生成对应的标签路径
def img2label_paths(img_paths):
"""Define label paths as a function of image paths."""
# sa 和 sb 分别表示图像和标签目录的子字符串(如 /images/ 和 /labels/)
sa, sb = f"{os.sep}images{os.sep}", f"{os.sep}labels{os.sep}" # /images/, /labels/ substrings
# 对于每个图像路径 x,使用 rsplit 将其分割为两部分,替换 /images/ 为 /labels/,并去掉文件扩展名,添加 .txt 后缀,生成标签路径
return [sb.join(x.rsplit(sa, 1)).rsplit(".", 1)[0] + ".txt" for x in img_paths]
# 计算一组路径(文件或目录)的哈希值
def get_hash(paths):
"""Returns a single hash value of a list of paths (files or dirs)."""
# 计算所有路径的总大小(size)
size = sum(os.path.getsize(p) for p in paths if os.path.exists(p)) # sizes
# 使用 SHA-256 哈希算法对总大小进行哈希
h = hashlib.sha256(str(size).encode()) # hash sizes
# 将所有路径的字符串连接起来,更新哈希值
h.update("".join(paths).encode()) # hash paths
# 返回最终的哈希值(十六进制字符串)
return h.hexdigest() # return hash
# 获取图像的实际大小,考虑 EXIF 信息中的旋转
def exif_size(img: Image.Image):
"""Returns exif-corrected PIL size."""
# 获取图像的原始大小 s(宽度和高度)
s = img.size # (width, height)
# 如果图像是 JPEG 格式,尝试读取 EXIF 信息中的旋转标签(键为 274)
if img.format == "JPEG": # only support JPEG images
try:
if exif := img.getexif():
rotation = exif.get(274, None) # the EXIF key for the orientation tag is 274
# 如果旋转角度为 90 或 270 度,交换宽度和高度
if rotation in {6, 8}: # rotation 270 or 90
s = s[1], s[0]
except Exception:
pass
# 返回修正后的大小
return s
# 验证单个图像是否有效
def verify_image(args):
"""Verify one image."""
(im_file, cls), prefix = args
# Number (found, corrupt), message
nf, nc, msg = 0, 0, ""
try:
# 打开图像文件并验证其完整性
im = Image.open(im_file)
im.verify() # PIL verify
# 获取图像大小并检查是否小于 10 像素
shape = exif_size(im) # image size
shape = (shape[1], shape[0]) # hw
assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels"
# 检查图像格式是否在支持的格式列表中
assert im.format.lower() in IMG_FORMATS, f"Invalid image format {im.format}. {FORMATS_HELP_MSG}"
# 如果图像是 JPEG 格式,检查是否损坏(JPEG 文件末尾应为 \xff\xd9)。如果损坏,尝试修复并保存。
if im.format.lower() in {"jpg", "jpeg"}:
with open(im_file, "rb") as f:
f.seek(-2, 2)
if f.read() != b"\xff\xd9": # corrupt JPEG
ImageOps.exif_transpose(Image.open(im_file)).save(im_file, "JPEG", subsampling=0, quality=100)
msg = f"{prefix}WARNING ⚠️ {im_file}: corrupt JPEG restored and saved"
nf = 1
except Exception as e:
nc = 1
msg = f"{prefix}WARNING ⚠️ {im_file}: ignoring corrupt image/label: {e}"
# 返回验证结果:
# (im_file, cls):图像文件路径和类别。
# nf:是否找到图像(1 表示找到,0 表示未找到)。
# nc:是否损坏(1 表示损坏,0 表示未损坏)。
# msg:验证消息。
return (im_file, cls), nf, nc, msg
# 验证图像及其对应的标签文件是否有效。
# args 是一个元组,包含以下内容:
# im_file:图像文件路径。
# lb_file:标签文件路径。
# prefix:日志前缀,用于打印警告信息。
# keypoint:布尔值,表示是否是关键点检测任务。
# num_cls:数据集中类别的总数。
# nkpt:每个对象的关键点数量(仅在关键点检测任务中使用)。
# ndim:关键点的维度(通常是 2 或 3)。
def verify_image_label(args):
"""Verify one image-label pair."""
im_file, lb_file, prefix, keypoint, num_cls, nkpt, ndim = args
# Number (missing, found, empty, corrupt), message, segments, keypoints
# nm:标记标签是否缺失(1 表示缺失,0 表示不缺失)。
# nf:标记标签是否找到(1 表示找到,0 表示未找到)。
# ne:标记标签是否为空(1 表示为空,0 表示不为空)。
# nc:标记图像或标签是否损坏(1 表示损坏,0 表示未损坏)。
# msg:用于存储验证过程中的警告信息。
# segments:用于存储多边形分割数据。
# keypoints:用于存储关键点数据。
nm, nf, ne, nc, msg, segments, keypoints = 0, 0, 0, 0, "", [], None
try:
# Verify images
# 验证图像:
# 打开图像文件并验证其完整性。
# 获取图像大小并检查是否小于 10 像素。
# 检查图像格式是否在支持的格式列表中。
# 如果图像是 JPEG 格式,检查是否损坏并尝试修复。
im = Image.open(im_file)
im.verify() # PIL verify 检查图像文件是否损坏。如果图像文件损坏,会抛出异常
# 获取图像的实际大小(高度和宽度),并考虑 EXIF 信息中的旋转
shape = exif_size(im) # image size exif_size(im):调用之前定义的函数,根据 EXIF 信息修正图像大小。
shape = (shape[1], shape[0]) # hw
# 确保图像的宽度和高度都大于 9 像素。如果不符合,抛出断言错误。
assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels"
# 检查图像格式是否在支持的格式列表中。如果不在,抛出断言错误。
assert im.format.lower() in IMG_FORMATS, f"invalid image format {im.format}. {FORMATS_HELP_MSG}"
if im.format.lower() in {"jpg", "jpeg"}:
# 打开图像文件并跳到文件末尾的最后两个字节
with open(im_file, "rb") as f:
# f.seek() 是文件操作中用于移动文件指针位置的方法。
# 具体来说,f.seek(-2, 2) 的作用是将文件指针移动到文件末尾的倒数第二个字节。
# f.seek(offset, whence)
# offset:偏移量,表示移动的字节数。
# whence:移动的基准位置,可以取以下值:
# 0(默认值):从文件开头开始计算偏移量。
# 1:从当前文件指针位置开始计算偏移量。
# 2:从文件末尾开始计算偏移量。
f.seek(-2, 2)
# 如果最后两个字节不是 \xff\xd9(JPEG 文件的结束标志),则认为图像损坏。
if f.read() != b"\xff\xd9": # corrupt JPEG
# 使用 ImageOps.exif_transpose 修复图像并保存
ImageOps.exif_transpose(Image.open(im_file)).save(im_file, "JPEG", subsampling=0, quality=100)
msg = f"{prefix}WARNING ⚠️ {im_file}: corrupt JPEG restored and saved"
# Verify labels
# 验证标签:
# 检查标签文件是否存在。
# 如果标签文件存在:
# 读取标签文件内容并解析为 NumPy 数组。
# 如果标签包含多边形(segments),将其转换为边界框。
# 如果是关键点检测任务,检查标签的列数是否符合要求(5 + nkpt * ndim)。
# 检查标签的坐标是否在 [0, 1] 范围内。
# 检查标签类别是否超出数据集的类别范围。
# 检查是否有重复的标签并移除。
# 如果标签文件不存在:
# 标记为“标签缺失”(nm = 1)。
# 返回空的标签数组。
# 如果标签为空:
# 标记为“标签为空”(ne = 1)。
# 返回空的标签数组。
if os.path.isfile(lb_file):
nf = 1 # label found 检查标签文件是否存在。如果存在,标记为找到(nf = 1)
# 读取标签文件并解析为二维列表。每一行表示一个标签,每一列是标签的组成部分(如类别、坐标等)。
with open(lb_file) as f:
lb = [x.split() for x in f.read().strip().splitlines() if len(x)]
# 检查标签是否包含多边形分割数据
# 如果标签的列数大于 6 且不是关键点检测任务,则认为是多边形分割数据。
if any(len(x) > 6 for x in lb) and (not keypoint): # is segment
# 提取类别信息并将其转换为 NumPy 数组。
classes = np.array([x[0] for x in lb], dtype=np.float32)
# 将多边形数据转换为边界框格式(segments2boxes 函数)。
segments = [np.array(x[1:], dtype=np.float32).reshape(-1, 2) for x in lb] # (cls, xy1...)
# 将类别和边界框合并为一个 NumPy 数组
lb = np.concatenate((classes.reshape(-1, 1), segments2boxes(segments)), 1) # (cls, xywh)
# 将标签数据转换为 NumPy 数组,方便后续处理
lb = np.array(lb, dtype=np.float32)
# 作用:获取标签的数量。如果数量为 0,表示标签为空。
if nl := len(lb):
# 检查标签的格式是否正确。
# 如果是关键点检测任务,标签的列数应为 5 + nkpt * ndim。
if keypoint:
assert lb.shape[1] == (5 + nkpt * ndim), f"labels require {(5 + nkpt * ndim)} columns each"
points = lb[:, 5:].reshape(-1, ndim)[:, :2] # 提取坐标信息
# 否则,标签的列数应为 5(类别、x、y、w、h)。
else:
assert lb.shape[1] == 5, f"labels require 5 columns, {lb.shape[1]} columns detected"
points = lb[:, 1:]
# 检查标签的坐标是否在合法范围内([0, 1])
# 如果坐标值大于 1 或小于 0,抛出断言错误
assert points.max() <= 1, f"non-normalized or out of bounds coordinates {points[points > 1]}"
assert lb.min() >= 0, f"negative label values {lb[lb < 0]}"
# All labels
# 检查标签中的类别是否超出数据集的类别范围
# 获取标签中最大的类别值
max_cls = lb[:, 0].max() # max label count
# 如果最大类别值大于数据集的类别总数,抛出断言错误
assert max_cls < num_cls, (
f"Label class {int(max_cls)} exceeds dataset class count {num_cls}. "
f"Possible class labels are 0-{num_cls - 1}"
)
# 检查标签中是否有重复行,并移除重复的标签
# 使用 np.unique 检查重复行
_, i = np.unique(lb, axis=0, return_index=True)
# 如果有重复行,移除重复的标签,并更新多边形分割数据(如果有)
if len(i) < nl: # duplicate row check
lb = lb[i] # remove duplicates
if segments:
segments = [segments[x] for x in i]
# 记录警告信息
msg = f"{prefix}WARNING ⚠️ {im_file}: {nl - len(i)} duplicate labels removed"
# 如果标签为空,标记为“标签为空”(ne = 1),并返回一个空的标签数组。
else:
ne = 1 # label empty
lb = np.zeros((0, (5 + nkpt * ndim) if keypoint else 5), dtype=np.float32)
else:
nm = 1 # label missing
lb = np.zeros((0, (5 + nkpt * ndim) if keypoints else 5), dtype=np.float32)
# 处理关键点数据
if keypoint:
# 提取关键点数据并将其重塑为 (nl, nkpt, ndim) 的形状
keypoints = lb[:, 5:].reshape(-1, nkpt, ndim)
if ndim == 2:
# 如果关键点的坐标小于 0,生成一个掩码(kpt_mask)
kpt_mask = np.where((keypoints[..., 0] < 0) | (keypoints[..., 1] < 0), 0.0, 1.0).astype(np.float32)
# 将关键点数据和掩码合并为一个数组
keypoints = np.concatenate([keypoints, kpt_mask[..., None]], axis=-1) # (nl, nkpt, 3)
# 仅保留标签的前 5 列(类别、x、y、w、h)
lb = lb[:, :5]
return im_file, lb, shape, segments, keypoints, nm, nf, ne, nc, msg
except Exception as e:
nc = 1
msg = f"{prefix}WARNING ⚠️ {im_file}: ignoring corrupt image/label: {e}"
# 返回值:
# im_file:图像文件路径
# lb:标签数组
# shape:图像大小(高度和宽度)
# segments:多边形分割数据(如果有)
# keypoints:关键点数据(如果有)
# nm:标签是否缺失(1 表示缺失,0 表示不缺失)
# nf:标签是否找到(1 表示找到,0 表示未找到)
# ne:标签是否为空(1 表示为空,0 表示不为空)
# nc:是否损坏(1 表示损坏,0 表示未损坏)
# msg:验证消息。
return [None, None, None, None, None, nm, nf, ne, nc, msg]
# 用于在图像上可视化 YOLO 格式的标注信息(边界框和类别标签)。
# 它读取图像及其对应的标注文件,然后在图像上绘制边界框并标注类别名称。
# 边界框的颜色根据类别 ID 分配,文本颜色则根据背景颜色的亮度动态调整,以确保可读性。
# image_path:图像文件路径。
# txt_path:标注文件路径(YOLO 格式)。
# label_map:类别 ID 到类别名称的映射字典。
def visualize_image_annotations(image_path, txt_path, label_map):
"""
Visualizes YOLO annotations (bounding boxes and class labels) on an image.
This function reads an image and its corresponding annotation file in YOLO format, then
draws bounding boxes around detected objects and labels them with their respective class names.
The bounding box colors are assigned based on the class ID, and the text color is dynamically
adjusted for readability, depending on the background color's luminance.
Args:
image_path (str): The path to the image file to annotate, and it can be in formats supported by PIL (e.g., .jpg, .png).
txt_path (str): The path to the annotation file in YOLO format, that should contain one line per object with:
- class_id (int): The class index.
- x_center (float): The X center of the bounding box (relative to image width).
- y_center (float): The Y center of the bounding box (relative to image height).
- width (float): The width of the bounding box (relative to image width).
- height (float): The height of the bounding box (relative to image height).
label_map (dict): A dictionary that maps class IDs (integers) to class labels (strings).
Examples:
>>> label_map = {0: "cat", 1: "dog", 2: "bird"} # It should include all annotated classes details
>>> visualize_image_annotations("path/to/image.jpg", "path/to/annotations.txt", label_map)
"""
import matplotlib.pyplot as plt # 用于绘图
# 从 Ultralytics 的工具中导入颜色函数,用于根据类别 ID 获取颜色。
from ultralytics.utils.plotting import colors
# 使用 PIL 打开图像文件,并将其转换为 NumPy 数组。
img = np.array(Image.open(image_path))
# 获取图像的高度和宽度
img_height, img_width = img.shape[:2]
# 读取 YOLO 格式的标注文件
annotations = []
with open(txt_path) as file:
for line in file:
# 每行标注包含 5 个值:class_id、x_center、y_center、width 和 height。
# 这些值是相对于图像宽度和高度的归一化坐标
# 将归一化坐标转换为绝对坐标(像素值)
# 计算边界框的左上角坐标 (x, y) 和宽度 w、高度 h。
class_id, x_center, y_center, width, height = map(float, line.split())
x = (x_center - width / 2) * img_width
y = (y_center - height / 2) * img_height
w = width * img_width
h = height * img_height
# 将边界框信息和类别 ID 添加到 annotations 列表中
annotations.append((x, y, w, h, int(class_id)))
# 绘制图像和标注
# 创建一个绘图窗口
fig, ax = plt.subplots(1) # Plot the image and annotations
# 遍历标注并绘制边界框, 为每个标注绘制边界框
for x, y, w, h, label in annotations:
# 使用 colors(label, True) 获取类别 ID 对应的颜色
# 将颜色值归一化到 [0, 1] 范围
color = tuple(c / 255 for c in colors(label, True)) # Get and normalize the RGB color
# 使用 plt.Rectangle 创建一个矩形(边界框),并将其添加到绘图中
rect = plt.Rectangle((x, y), w, h, linewidth=2, edgecolor=color, facecolor="none") # Create a rectangle
ax.add_patch(rect)
# 根据背景颜色的亮度动态调整文本颜色
# 使用公式计算颜色的亮度(luminance)
luminance = 0.2126 * color[0] + 0.7152 * color[1] + 0.0722 * color[2] # Formula for luminance
# 如果亮度小于 0.5,使用白色文本;否则使用黑色文本
# 使用 ax.text 在边界框上方绘制类别名称
ax.text(x, y - 5, label_map[label], color="white" if luminance < 0.5 else "black", backgroundcolor=color)
# 在绘图窗口中显示图像及其标注
ax.imshow(img)
plt.show()
# 将单个多边形转换为指定图像大小的二值掩码
def polygon2mask(imgsz, polygons, color=1, downsample_ratio=1):
"""
Convert a list of polygons to a binary mask of the specified image size.
Args:
imgsz (tuple): The size of the image as (height, width).
polygons (list[np.ndarray]): A list of polygons. Each polygon is an array with shape [N, M], where
N is the number of polygons, and M is the number of points such that M % 2 = 0.
color (int, optional): The color value to fill in the polygons on the mask. Defaults to 1.
downsample_ratio (int, optional): Factor by which to downsample the mask. Defaults to 1.
Returns:
(np.ndarray): A binary mask of the specified image size with the polygons filled in.
"""
# 创建一个与图像大小相同的全零数组,表示掩码
mask = np.zeros(imgsz, dtype=np.uint8)
# 将多边形数据转换为 NumPy 数组,并确保其形状为 (N, M, 2),其中 N 是多边形的数量,M 是每个多边形的顶点数。
polygons = np.asarray(polygons, dtype=np.int32)
polygons = polygons.reshape((polygons.shape[0], -1, 2))
# 填充多边形:使用 OpenCV 的 fillPoly 函数将多边形填充到掩码中。color 参数指定填充的颜色值。
cv2.fillPoly(mask, polygons, color=color)
# 下采样:如果 downsample_ratio 大于 1,则对掩码进行下采样
nh, nw = (imgsz[0] // downsample_ratio, imgsz[1] // downsample_ratio)
# Note: fillPoly first then resize is trying to keep the same loss calculation method when mask-ratio=1
# 使用 OpenCV 的 resize 函数调整掩码的大小。
# 返回一个二值掩码,其中多边形区域被填充为指定的颜色值。
return cv2.resize(mask, (nw, nh))
# 将一组多边形转换为多个二值掩码。
def polygons2masks(imgsz, polygons, color, downsample_ratio=1):
"""
Convert a list of polygons to a set of binary masks of the specified image size.
Args:
imgsz (tuple): The size of the image as (height, width).
polygons (list[np.ndarray]): A list of polygons. Each polygon is an array with shape [N, M], where
N is the number of polygons, and M is the number of points such that M % 2 = 0.
color (int): The color value to fill in the polygons on the masks.
downsample_ratio (int, optional): Factor by which to downsample each mask. Defaults to 1.
Returns:
(np.ndarray): A set of binary masks of the specified image size with the polygons filled in.
"""
# 循环处理每个多边形:
# 对于每个多边形 x,调用 polygon2mask 函数生成一个二值掩码。
# 使用列表推导式生成所有多边形的掩码。
# 返回值:将生成的掩码列表转换为 NumPy 数组并返回。
return np.array([polygon2mask(imgsz, [x.reshape(-1)], color, downsample_ratio) for x in polygons])
# 处理多边形重叠问题,生成一个掩码,其中每个像素值表示该像素属于的多边形索引。
def polygons2masks_overlap(imgsz, segments, downsample_ratio=1):
"""Return a (640, 640) overlap mask."""
# 创建一个全零数组,表示最终的掩码。
# 如果多边形数量超过 255,使用 np.int32 类型;否则使用 np.uint8 类型。
masks = np.zeros(
(imgsz[0] // downsample_ratio, imgsz[1] // downsample_ratio),
dtype=np.int32 if len(segments) > 255 else np.uint8,
)
areas = []
ms = []
# 生成单个多边形掩码
# 对每个多边形生成一个二值掩码。
# 计算每个掩码的面积(即掩码中值为 1 的像素数)。
for si in range(len(segments)):
mask = polygon2mask(imgsz, [segments[si].reshape(-1)], downsample_ratio=downsample_ratio, color=1)
ms.append(mask.astype(masks.dtype))
areas.append(mask.sum())
# 按面积从大到小对多边形进行排序
areas = np.asarray(areas)
index = np.argsort(-areas)
ms = np.array(ms)[index]
# 对每个多边形掩码乘以其索引值(i + 1),并将其累加到最终掩码中。
# 使用 np.clip 确保掩码中的值不超过当前索引值。
for i in range(len(segments)):
mask = ms[i] * (i + 1)
masks = masks + mask
masks = np.clip(masks, a_min=0, a_max=i + 1)
# 返回两个值:
# masks:最终的掩码,其中每个像素值表示该像素属于的多边形索引。
# index:多边形的排序索引。
return masks, index
# 在指定目录中查找与数据集关联的 YAML 文件。
def find_dataset_yaml(path: Path) -> Path:
"""
Find and return the YAML file associated with a Detect, Segment or Pose dataset.
This function searches for a YAML file at the root level of the provided directory first, and if not found, it
performs a recursive search. It prefers YAML files that have the same stem as the provided path. An AssertionError
is raised if no YAML file is found or if multiple YAML files are found.
Args:
path (Path): The directory path to search for the YAML file.
Returns:
(Path): The path of the found YAML file.
"""
# 首先在目录的根级别搜索 .yaml 文件。如果没有找到,再进行递归搜索
files = list(path.glob("*.yaml")) or list(path.rglob("*.yaml")) # try root level first and then recursive
# 断言:必须找到至少一个 YAML 文件:如果没有找到任何 YAML 文件,抛出 AssertionError
assert files, f"No YAML file found in '{path.resolve()}'"
# 如果找到多个 YAML 文件,优先选择与 path 的文件名(stem)匹配的文件
if len(files) > 1:
files = [f for f in files if f.stem == path.stem] # prefer *.yaml files that match
# 断言:必须找到且仅找到一个 YAML 文件: 如果找到多个文件,抛出 AssertionError
assert len(files) == 1, f"Expected 1 YAML file in '{path.resolve()}', but found {len(files)}.\n{files}"
# 返回找到的 YAML 文件路径
return files[0]
# 检查目标检测数据集是否存在,如果不存在则下载并解压,然后解析 YAML 文件并验证数据集结构。
def check_det_dataset(dataset, autodownload=True):
"""
Download, verify, and/or unzip a dataset if not found locally.
This function checks the availability of a specified dataset, and if not found, it has the option to download and
unzip the dataset. It then reads and parses the accompanying YAML data, ensuring key requirements are met and also
resolves paths related to the dataset.
Args:
dataset (str): Path to the dataset or dataset descriptor (like a YAML file).
autodownload (bool, optional): Whether to automatically download the dataset if not found. Defaults to True.
Returns:
(dict): Parsed dataset information and paths.
"""
# 调用 check_file 函数检查数据集文件是否存在
file = check_file(dataset)
# 下载并解压数据集(如果需要)
# Download (optional)
extract_dir = ""
# 如果文件是 ZIP 或 TAR 格式,调用 safe_download 函数下载并解压
if zipfile.is_zipfile(file) or is_tarfile(file):
new_dir = safe_download(file, dir=DATASETS_DIR, unzip=True, delete=False)
# 使用 find_dataset_yaml 查找解压后的 YAML 文件
file = find_dataset_yaml(DATASETS_DIR / new_dir)
# 设置 extract_dir 为解压目录,并禁用自动下载
extract_dir, autodownload = file.parent, False
# Read YAML
# 使用 yaml_load 函数加载 YAML 文件内容
data = yaml_load(file, append_filename=True) # dictionary
# Checks
# 确保 YAML 文件中包含 train 和 val 键。
# 如果键名是 validation,将其重命名为 val。
for k in "train", "val":
if k not in data:
if k != "val" or "validation" not in data:
raise SyntaxError(
emojis(f"{dataset} '{k}:' key missing ❌.\n'train' and 'val' are required in all data YAMLs.")
)
LOGGER.info("WARNING ⚠️ renaming data YAML 'validation' key to 'val' to match YOLO format.")
data["val"] = data.pop("validation") # replace 'validation' key with 'val' key
# 确保 YAML 文件中包含 names 或 nc。
if "names" not in data and "nc" not in data:
raise SyntaxError(emojis(f"{dataset} key missing ❌.\n either 'names' or 'nc' are required in all data YAMLs."))
# 如果两者都存在,确保它们的数量匹配。
if "names" in data and "nc" in data and len(data["names"]) != data["nc"]:
raise SyntaxError(emojis(f"{dataset} 'names' length {len(data['names'])} and 'nc: {data['nc']}' must match."))
# 如果没有 names,生成默认的类别名称。
if "names" not in data:
data["names"] = [f"class_{i}" for i in range(data["nc"])]
# 如果有 names,更新 nc。
else:
data["nc"] = len(data["names"])
# 调用 check_class_names 函数验证类别名称
data["names"] = check_class_names(data["names"])
# Resolve paths 解析数据集的根路径。
path = Path(extract_dir or data.get("path") or Path(data.get("yaml_file", "")).parent) # dataset root
# 如果路径不是绝对路径,将其解析为绝对路径
if not path.is_absolute():
path = (DATASETS_DIR / path).resolve()
# Set paths 确保所有路径都是绝对路径
data["path"] = path # download scripts
for k in "train", "val", "test", "minival":
if data.get(k): # prepend path
if isinstance(data[k], str):
x = (path / data[k]).resolve()
# 如果路径以 ../ 开头,尝试修正路径
if not x.exists() and data[k].startswith("../"):
x = (path / data[k][3:]).resolve()
data[k] = str(x)
else:
data[k] = [str((path / x).resolve()) for x in data[k]]
# Parse YAML
# 如果验证集路径不存在,尝试下载数据集。
# 支持从 URL 下载 ZIP 文件、运行 bash 脚本或执行 Python 脚本。
# 从 data 字典中获取 val(验证集路径)和 download(下载指令)。
# val:验证集路径,可以是字符串或列表
# s:下载指令,可以是 URL、bash 脚本或 Python 脚本
val, s = (data.get(x) for x in ("val", "download"))
# 将 val 路径解析为绝对路径
if val:
# 如果 val 是列表,则解析每个路径为绝对路径。
# 如果 val 是字符串,则将其转换为列表并解析为绝对路径。
val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])] # val path
# 检查所有验证集路径是否存在:如果任何路径不存在,进入下载流程
if not all(x.exists() for x in val):
# 获取数据集名称
# 从 dataset 中提取数据集名称,去除 URL 中的认证信息
# clean_url 是一个辅助函数,用于清理 URL 中的认证信息(如用户名和密码)。
name = clean_url(dataset) # dataset name with URL auth stripped
# 构建警告信息,指出缺失的路径。
# 使用列表推导式找到第一个不存在的路径,并将其包含在警告信息中。
m = f"\nDataset '{name}' images not found ⚠️, missing path '{[x for x in val if not x.exists()][0]}'"
# 检查是否需要自动下载
# 如果 s 存在且 autodownload 为 True,记录警告信息并尝试下载。
if s and autodownload:
LOGGER.warning(m)
# 否则,记录完整的错误信息并抛出 FileNotFoundError。
else:
m += f"\nNote dataset download directory is '{DATASETS_DIR}'. You can update this in '{SETTINGS_FILE}'"
raise FileNotFoundError(m)
# 记录开始时间,并初始化返回值 r。
t = time.time()
r = None # success
# 根据下载指令类型执行操作
# 作用:如果 s 是一个以 .zip 结尾的 URL,调用 safe_download 函数下载并解压 ZIP 文件。
if s.startswith("http") and s.endswith(".zip"): # URL
# safe_download 是一个辅助函数,用于安全地下载文件并解压到指定目录
safe_download(url=s, dir=DATASETS_DIR, delete=True)
# 如果 s 是一个以 bash 开头的脚本,运行该脚本。
elif s.startswith("bash "): # bash script
LOGGER.info(f"Running {s} ...")
# 使用 os.system 执行 bash 脚本。
r = os.system(s)
else: # python script
exec(s, {"yaml": data})
# 作用:记录下载操作的耗时和结果。
# 计算下载操作的耗时。
# 根据返回值 r 判断下载是否成功,并记录相应的信息。
dt = f"({round(time.time() - t, 1)}s)"
s = f"success ✅ {dt}, saved to {colorstr('bold', DATASETS_DIR)}" if r in {0, None} else f"failure {dt} ❌"
LOGGER.info(f"Dataset download {s}\n")
# 作用:根据类别名称是否包含 ASCII 字符,下载相应的字体文件。
# 如果类别名称只包含 ASCII 字符,下载 Arial.ttf。
# 否则,下载 Arial.Unicode.ttf。
# check_font 是一个辅助函数,用于检查并下载字体文件。
check_font("Arial.ttf" if is_ascii(data["names"]) else "Arial.Unicode.ttf") # download fonts
# 返回解析后的数据集信息。
return data # dictionary
# 检查分类数据集(如 ImageNet)是否存在,如果不存在则尝试下载。
def check_cls_dataset(dataset, split=""):
"""
Checks a classification dataset such as Imagenet.
This function accepts a `dataset` name and attempts to retrieve the corresponding dataset information.
If the dataset is not found locally, it attempts to download the dataset from the internet and save it locally.
Args:
dataset (str | Path): The name of the dataset.
split (str, optional): The split of the dataset. Either 'val', 'test', or ''. Defaults to ''.
Returns:
(dict): A dictionary containing the following keys:
- 'train' (Path): The directory path containing the training set of the dataset.
- 'val' (Path): The directory path containing the validation set of the dataset.
- 'test' (Path): The directory path containing the test set of the dataset.
- 'nc' (int): The number of classes in the dataset.
- 'names' (dict): A dictionary of class names in the dataset.
"""
# Download (optional if dataset=https://file.zip is passed directly)
# 下载数据集(如果需要)
# 如果 dataset 是一个 URL,下载并解压
if str(dataset).startswith(("http:/", "https:/")):
dataset = safe_download(dataset, dir=DATASETS_DIR, unzip=True, delete=False)
# 如果 dataset 是一个压缩文件,解压到指定目录
elif Path(dataset).suffix in {".zip", ".tar", ".gz"}:
file = check_file(dataset)
dataset = safe_download(file, dir=DATASETS_DIR, unzip=True, delete=False)
# 解析数据集路径:确保 dataset 是一个目录路径
dataset = Path(dataset)
data_dir = (dataset if dataset.is_dir() else (DATASETS_DIR / dataset)).resolve()
# 检查数据集是否存在:如果数据集不存在,尝试下载
if not data_dir.is_dir():
LOGGER.warning(f"\nDataset not found ⚠️, missing path {data_dir}, attempting download...")
t = time.time()
if str(dataset) == "imagenet":
subprocess.run(f"bash {ROOT / 'data/scripts/get_imagenet.sh'}", shell=True, check=True)
else:
url = f"https://github.com/ultralytics/assets/releases/download/v0.0.0/{dataset}.zip"
download(url, dir=data_dir.parent)
s = f"Dataset download success ✅ ({time.time() - t:.1f}s), saved to {colorstr('bold', data_dir)}\n"
LOGGER.info(s)
# 解析数据集结构:确定训练集、验证集和测试集的路径。
train_set = data_dir / "train"
val_set = (
data_dir / "val"
if (data_dir / "val").exists()
else data_dir / "validation"
if (data_dir / "validation").exists()
else None
) # data/test or data/val
# 检查 data_dir 目录下是否存在 test 文件夹。
# 如果 data_dir/test 存在,则将 test_set 设置为该路径。
# 如果不存在,则将 test_set 设置为 None。
test_set = data_dir / "test" if (data_dir / "test").exists() else None # data/val or data/test
# 处理 split 参数: 如果指定的 split 不存在,尝试使用其他分割。
# 如果 split 参数为 "val" 且 val_set 不存在,则记录警告信息并尝试使用 test_set。
if split == "val" and not val_set:
LOGGER.warning("WARNING ⚠️ Dataset 'split=val' not found, using 'split=test' instead.")
# 如果 split 参数为 "test" 且 test_set 不存在,则记录警告信息并尝试使用 val_set。
elif split == "test" and not test_set:
LOGGER.warning("WARNING ⚠️ Dataset 'split=test' not found, using 'split=val' instead.")
# 统计类别数量和名称:统计训练集中的类别数量和名称。
# 使用 (data_dir / "train").glob("*") 获取训练集目录下的所有子目录。
# 过滤出所有是目录的子目录,统计其数量(nc)。
nc = len([x for x in (data_dir / "train").glob("*") if x.is_dir()]) # number of classes
# 提取每个子目录的名称(类别名称),并将其排序后存储到 names 列表中。
names = [x.name for x in (data_dir / "train").iterdir() if x.is_dir()] # class names list
# 将 names 列表转换为字典,键为类别索引,值为类别名称。
names = dict(enumerate(sorted(names)))
# Print to console
# 打印数据集信息:打印每个数据集分割的详细信息,包括文件数量和类别数量。
# 遍历 train、val 和 test 分割
for k, v in {"train": train_set, "val": val_set, "test": test_set}.items():
# 使用 colorstr 函数为日志信息添加颜色
prefix = f"{colorstr(f'{k}:')} {v}..."
# 如果分割路径 v 为 None,记录其路径
if v is None:
LOGGER.info(prefix)
# 如果分割路径存在
else:
# 使用 rglob 获取该路径下所有文件
files = [path for path in v.rglob("*.*") if path.suffix[1:].lower() in IMG_FORMATS]
# 过滤出文件扩展名在 IMG_FORMATS 中的文件(图像文件)
nf = len(files) # number of files
# 统计文件数量(nf)和文件所在的目录数量(nd)
nd = len({file.parent for file in files}) # number of directories
# 如果没有找到任何图像文件
if nf == 0:
# 如果是训练集,抛出 FileNotFoundError
if k == "train":
raise FileNotFoundError(emojis(f"{dataset} '{k}:' no training images found ❌ "))
# 如果是其他分割,记录警告信息
else:
LOGGER.warning(f"{prefix} found {nf} images in {nd} classes: WARNING ⚠️ no images found")
elif nd != nc:
LOGGER.warning(f"{prefix} found {nf} images in {nd} classes: ERROR ❌️ requires {nc} classes, not {nd}")
else:
LOGGER.info(f"{prefix} found {nf} images in {nd} classes ✅ ")
# 返回数据集信息:返回数据集的路径、类别数量和类别名称
# 返回值:
# train:训练集路径。
# val:验证集路径。
# test:测试集路径。
# nc:类别数量。
# names:类别名称字典。
return {"train": train_set, "val": val_set, "test": test_set, "nc": nc, "names": names}
# 用于生成 HUB 数据集的统计信息(JSON 文件)和 -hub 数据集目录
# 它支持多种任务(如目标检测、分割、关键点检测和分类),并提供了数据集统计和图像压缩功能。
class HUBDatasetStats:
"""
A class for generating HUB dataset JSON and `-hub` dataset directory.
Args:
path (str): Path to data.yaml or data.zip (with data.yaml inside data.zip). Default is 'coco8.yaml'.
task (str): Dataset task. Options are 'detect', 'segment', 'pose', 'classify'. Default is 'detect'.
autodownload (bool): Attempt to download dataset if not found locally. Default is False.
Example:
Download *.zip files from https://github.com/ultralytics/hub/tree/main/example_datasets
i.e. https://github.com/ultralytics/hub/raw/main/example_datasets/coco8.zip for coco8.zip.
```python
from ultralytics.data.utils import HUBDatasetStats
stats = HUBDatasetStats("path/to/coco8.zip", task="detect") # detect dataset
stats = HUBDatasetStats("path/to/coco8-seg.zip", task="segment") # segment dataset
stats = HUBDatasetStats("path/to/coco8-pose.zip", task="pose") # pose dataset
stats = HUBDatasetStats("path/to/dota8.zip", task="obb") # OBB dataset
stats = HUBDatasetStats("path/to/imagenet10.zip", task="classify") # classification dataset
stats.get_json(save=True)
stats.process_images()
```
"""
def __init__(self, path="coco8.yaml", task="detect", autodownload=False):
"""Initialize class."""
# 将路径解析为绝对路径。
path = Path(path).resolve()
# 记录开始检查数据集的日志信息。
LOGGER.info(f"Starting HUB dataset checks for {path}....")
# 设置任务类型(如 detect、segment、pose、classify 或 obb)
self.task = task # detect, segment, pose, classify, obb
# 处理分类任务
# 如果任务是分类(classify),解压文件并调用 check_cls_dataset 函数检查数据集。
if self.task == "classify":
unzip_dir = unzip_file(path)
data = check_cls_dataset(unzip_dir)
data["path"] = unzip_dir # 设置数据集路径。
else: # 处理其他任务(目标检测、分割、关键点检测、OBB
# 如果任务不是分类,调用 _unzip 方法解压文件。
_, data_dir, yaml_path = self._unzip(Path(path))
try:
# Load YAML with checks 加载 YAML 文件并进行检查
data = yaml_load(yaml_path)
data["path"] = "" # strip path since YAML should be in dataset root for all HUB datasets
yaml_save(yaml_path, data)
# 调用 check_det_dataset 函数检查数据集。
data = check_det_dataset(yaml_path, autodownload) # dict
data["path"] = data_dir # YAML path should be set to '' (relative) or parent (absolute)
# 捕获并处理可能出现的异常
except Exception as e:
raise Exception("error/HUB/dataset_stats/init") from e
# 设置 HUB 数据集目录和图像目录
self.hub_dir = Path(f"{data['path']}-hub")
self.im_dir = self.hub_dir / "images"
# 初始化统计信息字典,包含类别数量和类别名称。
self.stats = {"nc": len(data["names"]), "names": list(data["names"].values())} # statistics dictionary
# 保存数据集信息。
self.data = data
# 解压 .zip 文件并查找 YAML 文件。
@staticmethod
def _unzip(path):
"""Unzip data.zip."""
# 检查路径是否为 .zip 文件:如果路径不是 .zip 文件,返回 False 和路径。
if not str(path).endswith(".zip"): # path is data.yaml
return False, None, path
# 调用 unzip_file 函数解压文件到父目录
unzip_dir = unzip_file(path, path=path.parent)
# 确保解压目录存在。
assert unzip_dir.is_dir(), (
f"Error unzipping {path}, {unzip_dir} not found. path/to/abc.zip MUST unzip to path/to/abc/"
)
# 返回解压状态、解压目录和 YAML 文件路径。
return True, str(unzip_dir), find_dataset_yaml(unzip_dir) # zipped, data_dir, yaml_path
# 调用 compress_one_image 函数压缩图像并保存到 HUB 图像目录。
def _hub_ops(self, f):
"""Saves a compressed image for HUB previews."""
compress_one_image(f, self.im_dir / Path(f).name) # save to dataset-hub
# 生成 HUB 数据集的 JSON 文件。
def get_json(self, save=False, verbose=False):
"""Return dataset JSON for Ultralytics HUB."""
# 根据任务类型(detect、segment、obb、pose)处理标签数据
# 将类别标签转换为整数,坐标值保留 4 位小数
def _round(labels):
"""Update labels to integer class and 4 decimal place floats."""
if self.task == "detect":
coordinates = labels["bboxes"]
elif self.task in {"segment", "obb"}: # Segment and OBB use segments. OBB segments are normalized xyxyxyxy
coordinates = [x.flatten() for x in labels["segments"]]
elif self.task == "pose":
n, nk, nd = labels["keypoints"].shape
coordinates = np.concatenate((labels["bboxes"], labels["keypoints"].reshape(n, nk * nd)), 1)
else:
raise ValueError(f"Undefined dataset task={self.task}.")
zipped = zip(labels["cls"], coordinates)
return [[int(c[0]), *(round(float(x), 4) for x in points)] for c, points in zipped]
# 遍历数据集分割
# 遍历 train、val 和 test 分割。
# 初始化统计信息。
for split in "train", "val", "test":
self.stats[split] = None # predefine
path = self.data.get(split)
# Check split
# 如果路径不存在,跳过当前分割。
if path is None: # no split
continue
# 获取分割中的图像文件。
files = [f for f in Path(path).rglob("*.*") if f.suffix[1:].lower() in IMG_FORMATS] # image files in split
if not files: # no images
continue
# Get dataset statistics
# 统计分类数据集: 如果任务是分类,使用 ImageFolder 加载数据集。
# 统计每个类别的图像数量。
# 保存统计信息。
if self.task == "classify":
from torchvision.datasets import ImageFolder # scope for faster 'import ultralytics'
dataset = ImageFolder(self.data[split])
x = np.zeros(len(dataset.classes)).astype(int)
for im in dataset.imgs:
x[im[1]] += 1
self.stats[split] = {
"instance_stats": {"total": len(dataset), "per_class": x.tolist()},
"image_stats": {"total": len(dataset), "unlabelled": 0, "per_class": x.tolist()},
"labels": [{Path(k).name: v} for k, v in dataset.imgs],
}
# 统计其他任务的数据集
else:
from ultralytics.data import YOLODataset
# 如果任务不是分类,使用 YOLODataset 加载数据集。
dataset = YOLODataset(img_path=self.data[split], data=self.data, task=self.task)
# 统计每个类别的实例数量和图像数量
x = np.array(
[
np.bincount(label["cls"].astype(int).flatten(), minlength=self.data["nc"])
for label in TQDM(dataset.labels, total=len(dataset), desc="Statistics")
]
) # shape(128x80)
# 保存统计信息
self.stats[split] = {
"instance_stats": {"total": int(x.sum()), "per_class": x.sum(0).tolist()},
"image_stats": {
"total": len(dataset),
"unlabelled": int(np.all(x == 0, 1).sum()),
"per_class": (x > 0).sum(0).tolist(),
},
"labels": [{Path(k).name: _round(v)} for k, v in zip(dataset.im_files, dataset.labels)],
}
# Save, print and return
# 保存、打印并返回统计信息
# 如果 save 为 True,保存统计信息到 stats.json 文件。
if save:
self.hub_dir.mkdir(parents=True, exist_ok=True) # makes dataset-hub/
stats_path = self.hub_dir / "stats.json"
LOGGER.info(f"Saving {stats_path.resolve()}...")
with open(stats_path, "w") as f:
json.dump(self.stats, f) # save stats.json
# 如果 verbose 为 True,打印统计信息。
if verbose:
LOGGER.info(json.dumps(self.stats, indent=2, sort_keys=False))
# 返回统计信息。
return self.stats
# 压缩数据集中的图像并保存到 HUB 图像目录。
def process_images(self):
"""Compress images for Ultralytics HUB."""
from ultralytics.data import YOLODataset # ClassificationDataset
# 创建 HUB 图像目录。
self.im_dir.mkdir(parents=True, exist_ok=True) # makes dataset-hub/images/
# 遍历 train、val 和 test 分割。
for split in "train", "val", "test":
# 如果分割不存在,跳过当前分割。
if self.data.get(split) is None:
continue
# 使用 YOLODataset 加载数据集。
dataset = YOLODataset(img_path=self.data[split], data=self.data)
# 使用多线程池(ThreadPool)并行处理图像压缩。
with ThreadPool(NUM_THREADS) as pool:
# 调用 _hub_ops 方法压缩每个图像并保存到 HUB 图像目录。
for _ in TQDM(pool.imap(self._hub_ops, dataset.im_files), total=len(dataset), desc=f"{split} images"):
pass
# 记录完成信息。
LOGGER.info(f"Done. All images saved to {self.im_dir}")
# 返回 HUB 图像目录路径。
return self.im_dir
# 压缩单个图像文件,同时保持其宽高比和质量。如果图像尺寸小于最大尺寸,则不会调整大小。
def compress_one_image(f, f_new=None, max_dim=1920, quality=50):
"""
Compresses a single image file to reduced size while preserving its aspect ratio and quality using either the Python
Imaging Library (PIL) or OpenCV library. If the input image is smaller than the maximum dimension, it will not be
resized.
Args:
f (str): The path to the input image file.
f_new (str, optional): The path to the output image file. If not specified, the input file will be overwritten.
max_dim (int, optional): The maximum dimension (width or height) of the output image. Default is 1920 pixels.
quality (int, optional): The image compression quality as a percentage. Default is 50%.
Example:
```python
from pathlib import Path
from ultralytics.data.utils import compress_one_image
for f in Path("path/to/dataset").rglob("*.jpg"):
compress_one_image(f)
```
"""
try: # use PIL
im = Image.open(f)
# 计算图像的宽高比(r)。
r = max_dim / max(im.height, im.width) # ratio
# 如果图像的尺寸大于最大尺寸(max_dim),则调整图像大小。
if r < 1.0: # image too large
im = im.resize((int(im.width * r), int(im.height * r)))
# 使用指定的质量(quality)保存图像。如果 f_new 未指定,则覆盖原文件。
im.save(f_new or f, "JPEG", quality=quality, optimize=True) # save
# 如果 PIL 失败,使用 OpenCV 压缩图像
except Exception as e: # use OpenCV
# 如果 PIL 处理失败,记录警告信息。
LOGGER.info(f"WARNING ⚠️ HUB ops PIL failure {f}: {e}")
im = cv2.imread(f)
im_height, im_width = im.shape[:2]
r = max_dim / max(im_height, im_width) # ratio
if r < 1.0: # image too large
im = cv2.resize(im, (int(im_width * r), int(im_height * r)), interpolation=cv2.INTER_AREA)
cv2.imwrite(str(f_new or f), im)
# 自动将数据集分割为训练集、验证集和测试集,并将结果保存到 autosplit_*.txt 文件中。支持仅使用带标签的图像。
def autosplit(path=DATASETS_DIR / "coco8/images", weights=(0.9, 0.1, 0.0), annotated_only=False):
"""
Automatically split a dataset into train/val/test splits and save the resulting splits into autosplit_*.txt files.
Args:
path (Path, optional): Path to images directory. Defaults to DATASETS_DIR / 'coco8/images'.
weights (list | tuple, optional): Train, validation, and test split fractions. Defaults to (0.9, 0.1, 0.0).
annotated_only (bool, optional): If True, only images with an associated txt file are used. Defaults to False.
Example:
```python
from ultralytics.data.utils import autosplit
autosplit()
```
"""
path = Path(path) # images dir
# 使用 rglob 获取路径下的所有图像文件。
files = sorted(x for x in path.rglob("*.*") if x.suffix[1:].lower() in IMG_FORMATS) # image files only
# 计算图像文件的数量。
n = len(files) # number of files
# 设置随机种子为 0,确保结果可复现。
random.seed(0) # for reproducibility
# 使用 random.choices 根据权重(weights)将图像分配到训练集、验证集和测试集。
indices = random.choices([0, 1, 2], weights=weights, k=n) # assign each image to a split
# 定义三个分割文件名。
txt = ["autosplit_train.txt", "autosplit_val.txt", "autosplit_test.txt"] # 3 txt files
# 如果分割文件已存在,则删除。
for x in txt:
if (path.parent / x).exists():
(path.parent / x).unlink() # remove existing
# 记录分割信息,如果 annotated_only 为 True,则记录只使用带标签的图像
LOGGER.info(f"Autosplitting images from {path}" + ", using *.txt labeled images only" * annotated_only)
# 遍历图像文件和对应的分割索引。
for i, img in TQDM(zip(indices, files), total=n):
# 如果 annotated_only 为 True,检查是否存在对应的标签文件。
if not annotated_only or Path(img2label_paths([str(img)])[0]).exists(): # check label
# 将图像路径写入对应的分割文件。
with open(path.parent / txt[i], "a") as f:
f.write(f"./{img.relative_to(path.parent).as_posix()}" + "\n") # add image to txt file
# 从指定路径加载缓存文件(*.cache),并将其内容(一个字典)返回。在加载过程中,通过禁用垃圾回收器来减少加载时间
def load_dataset_cache_file(path):
"""Load an Ultralytics *.cache dictionary from path."""
import gc # 导入 Python 的垃圾回收器模块。
# 禁用垃圾回收器。这可以减少加载缓存文件时的开销,从而提高加载速度。这是从 Pull Request #1585 中提到的优化技巧。
gc.disable() # reduce pickle load time https://github.com/ultralytics/ultralytics/pull/1585
# 使用 np.load 加载缓存文件。allow_pickle=True 允许加载包含 Python 对象的 NumPy 文件。
# .item() 将加载的内容转换为 Python 字典。
cache = np.load(str(path), allow_pickle=True).item() # load dict
# 重新启用垃圾回收器。
gc.enable()
# 返回加载的缓存字典
return cache
# 将 Ultralytics 数据集的缓存字典(x)保存到指定路径(path),并添加缓存版本信息。如果缓存目录不可写,将记录警告信息。
def save_dataset_cache_file(prefix, path, x, version):
"""Save an Ultralytics dataset *.cache dictionary x to path."""
x["version"] = version # 在缓存字典中添加一个键值对,表示缓存的版本
# 使用 is_dir_writeable 函数检查缓存目录是否可写。如果不可写,后续操作将不会执行。
if is_dir_writeable(path.parent):
# 如果缓存文件已存在,则删除它。
if path.exists():
path.unlink() # remove *.cache file if exists
# 使用上下文管理器打开文件并保存缓存字典。上下文管理器确保文件正确关闭,避免潜在的文件描述符泄漏。
with open(str(path), "wb") as file: # context manager here fixes windows async np.save bug
np.save(file, x)
LOGGER.info(f"{prefix}New cache created: {path}")
else:
LOGGER.warning(f"{prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable, cache not saved.")