model.py
ultralytics\engine\model.py
目录
1.所需的库和模块
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
import inspect
from pathlib import Path
from typing import Any, Dict, List, Union
import numpy as np
import torch
from PIL import Image
from ultralytics.cfg import TASK2DATA, get_cfg, get_save_dir
from ultralytics.engine.results import Results
from ultralytics.hub import HUB_WEB_ROOT, HUBTrainingSession
from ultralytics.nn.tasks import attempt_load_one_weight, guess_model_task, nn, yaml_model_load
from ultralytics.utils import (
ARGV,
ASSETS,
DEFAULT_CFG_DICT,
LOGGER,
RANK,
SETTINGS,
callbacks,
checks,
emojis,
yaml_load,
)
2.class Model(nn.Module):
# 这段代码定义了一个名为 Model 的类,继承自 torch.nn.Module ,用于封装和操作 YOLO 模型的各种功能,包括模型加载、训练、预测、导出等。
# 定义了一个名为 Model 的类,继承自 PyTorch 的 nn.Module ,表示这是一个神经网络模型类。
class Model(nn.Module):
# 用于实现 YOLO 模型的基类,统一不同模型类型的 API。
# 此类为与 YOLO 模型相关的各种操作提供通用接口,例如训练、验证、预测、导出和基准测试。它处理不同类型的模型,包括从本地文件、Ultralytics HUB 或 Triton Server 加载的模型。
# 方法:
# __call__:预测方法的别名,使模型实例可调用。
# _new:根据配置文件初始化新模型。
# _load:从检查点文件加载模型。
# _check_is_pytorch_model:确保模型是 PyTorch 模型。
# reset_weights:将模型的权重重置为其初始状态。
# load:从指定文件加载模型权重。
# save:将模型的当前状态保存到文件。
# info:记录或返回有关模型的信息。
# fuse:融合 Conv2d 和 BatchNorm2d 层以优化推理。
# predict:执行对象检测预测。
# track:执行对象跟踪。
# val:在数据集上验证模型。
# benchmark:在各种导出格式上对模型进行基准测试。
# export:将模型导出为不同格式。
# train:在数据集上训练模型。
# tune:执行超参数调整。
# _apply:将函数应用于模型的张量。
# add_callback:为事件添加回调函数。
# clear_callback:清除事件的所有回调。
# reset_callbacks:将所有回调重置为其默认函数。
"""
A base class for implementing YOLO models, unifying APIs across different model types.
This class provides a common interface for various operations related to YOLO models, such as training,
validation, prediction, exporting, and benchmarking. It handles different types of models, including those
loaded from local files, Ultralytics HUB, or Triton Server.
Attributes:
callbacks (Dict): A dictionary of callback functions for various events during model operations.
predictor (BasePredictor): The predictor object used for making predictions.
model (nn.Module): The underlying PyTorch model.
trainer (BaseTrainer): The trainer object used for training the model.
ckpt (Dict): The checkpoint data if the model is loaded from a *.pt file.
cfg (str): The configuration of the model if loaded from a *.yaml file.
ckpt_path (str): The path to the checkpoint file.
overrides (Dict): A dictionary of overrides for model configuration.
metrics (Dict): The latest training/validation metrics.
session (HUBTrainingSession): The Ultralytics HUB session, if applicable.
task (str): The type of task the model is intended for.
model_name (str): The name of the model.
Methods:
__call__: Alias for the predict method, enabling the model instance to be callable.
_new: Initializes a new model based on a configuration file.
_load: Loads a model from a checkpoint file.
_check_is_pytorch_model: Ensures that the model is a PyTorch model.
reset_weights: Resets the model's weights to their initial state.
load: Loads model weights from a specified file.
save: Saves the current state of the model to a file.
info: Logs or returns information about the model.
fuse: Fuses Conv2d and BatchNorm2d layers for optimized inference.
predict: Performs object detection predictions.
track: Performs object tracking.
val: Validates the model on a dataset.
benchmark: Benchmarks the model on various export formats.
export: Exports the model to different formats.
train: Trains the model on a dataset.
tune: Performs hyperparameter tuning.
_apply: Applies a function to the model's tensors.
add_callback: Adds a callback function for an event.
clear_callback: Clears all callbacks for an event.
reset_callbacks: Resets all callbacks to their default functions.
Examples:
>>> from ultralytics import YOLO
>>> model = YOLO("yolo11n.pt")
>>> results = model.predict("image.jpg")
>>> model.train(data="coco8.yaml", epochs=3)
>>> metrics = model.val()
>>> model.export(format="onnx")
"""
# 这段代码定义了 Model 类的初始化方法 __init__ ,用于设置模型的基本属性,并根据输入参数加载或创建 YOLO 模型。
# 定义了 Model 类的初始化方法,接收以下参数 :
# 1.model :模型的路径或名称,默认为 "yolo11n.pt" 。可以是本地文件路径、Ultralytics HUB 模型标识符或 Triton Server 模型。
# 2.task :模型的任务类型(如检测、分割等),默认为 None 。
# 3.verbose :是否启用详细输出,默认为 False 。
def __init__(
self,
model: Union[str, Path] = "yolo11n.pt",
task: str = None,
verbose: bool = False,
) -> None:
# 初始化 YOLO 模型类的新实例。
# 此构造函数根据提供的模型路径或名称设置模型。它处理各种类型的模型源,包括本地文件、Ultralytics HUB 模型和 Triton Server 模型。该方法初始化模型的几个重要属性,并为训练、预测或导出等操作做好准备。
# 引发:
# FileNotFoundError:如果指定的模型文件不存在或无法访问。
# ValueError:如果模型文件或配置无效或不受支持。
# ImportError:如果未安装特定模型类型(如 HUB SDK)所需的依赖项。
"""
Initializes a new instance of the YOLO model class.
This constructor sets up the model based on the provided model path or name. It handles various types of
model sources, including local files, Ultralytics HUB models, and Triton Server models. The method
initializes several important attributes of the model and prepares it for operations like training,
prediction, or export.
Args:
model (Union[str, Path]): Path or name of the model to load or create. Can be a local file path, a
model name from Ultralytics HUB, or a Triton Server model.
task (str | None): The task type associated with the YOLO model, specifying its application domain.
verbose (bool): If True, enables verbose output during the model's initialization and subsequent
operations.
Raises:
FileNotFoundError: If the specified model file does not exist or is inaccessible.
ValueError: If the model file or configuration is invalid or unsupported.
ImportError: If required dependencies for specific model types (like HUB SDK) are not installed.
Examples:
>>> model = Model("yolo11n.pt")
>>> model = Model("path/to/model.yaml", task="detect")
>>> model = Model("hub_model", verbose=True)
"""
# 调用父类 nn.Module 的初始化方法,确保继承了 PyTorch 模块的基本功能。
super().__init__()
# 初始化回调函数列表,加载默认的回调函数。这些回调函数用于在模型训练、预测等过程中 执行特定操作 。
self.callbacks = callbacks.get_default_callbacks()
# 这段代码是 Model 类初始化方法中的一部分,主要功能是初始化类的各种属性,并对输入的 model 参数进行预处理。
# 初始化 predictor 属性为 None 。 predictor 是一个 用于执行模型预测的对象 ,后续可能会根据需要加载或重用。
self.predictor = None # reuse predictor
# 初始化 model 属性为 None 。 model 是 模型的核心对象 ,后续会根据输入的 model 参数加载或创建。
self.model = None # model object
# 初始化 trainer 属性为 None 。 trainer 是用于 训练模型的对象 ,后续可能会根据需要加载或初始化。
self.trainer = None # trainer object
# 初始化 ckpt 属性为一个空字典。 ckpt 用于 存储从 .pt 文件加载的模型检查点 (checkpoint)信息,例如模型权重、训练参数等。
self.ckpt = {} # if loaded from *.pt
# 初始化 cfg 属性为 None 。 cfg 用于 存储从 .yaml 文件加载的模型配置信息 ,例如模型结构、训练参数等。
self.cfg = None # if loaded from *.yaml
# 初始化 ckpt_path 属性为 None 。 ckpt_path 用于 存储加载的检查点文件路径 ,便于后续引用。
self.ckpt_path = None
# 初始化 overrides 属性为一个空字典。 overrides 用于 存储覆盖默认行为的参数 ,例如在训练过程中覆盖默认的训练配置。
self.overrides = {} # overrides for trainer object
# 初始化 metrics 属性为 None 。 metrics 用于 存储模型训练或验证过程中的性能指标 ,例如准确率、损失值等。
self.metrics = None # validation/training metrics
# 初始化 session 属性为 None 。 session 用于 存储与 Ultralytics HUB 相关的会话信息 ,例如从 HUB 加载模型时的会话状态。
self.session = None # HUB session
# 将传入的 task 参数赋值给 self.task 属性。 task 表示 模型的任务类型 ,例如目标检测( detect )、分割( segment )、分类( classify )等。
self.task = task # task type
# 将输入的 model 参数转换为字符串,并去除首尾的空白字符。这一步确保了 model 参数的格式是统一的,便于后续处理。
model = str(model).strip()
# 这段代码的作用是初始化 Model 类的各个属性,并对输入的 model 参数进行预处理。这些属性在后续的模型加载、训练、预测等操作中会被逐步填充和使用。通过将这些属性初始化为 None 或空值,代码确保了类的状态在初始化时是干净的,避免了潜在的冲突或错误。
# 这段代码的作用是检查输入的 model 参数是否指向一个 Ultralytics HUB 模型或 Triton Server 模型,并根据不同的来源执行相应的处理逻辑。
# Check if Ultralytics HUB model from https://hub.ultralytics.com
# 检查是否为 Ultralytics HUB 模型。
# 调用 self.is_hub_model(model) 方法,判断输入的 model 是否为一个 Ultralytics HUB 模型。 is_hub_model 方法通常会检查 model 是否包含特定的 URL 格式(例如以 https://hub.ultralytics.com/models/ 开头)。 如果返回 True ,则表示 model 是一个 HUB 模型。
if self.is_hub_model(model):
# Fetch model from HUB
# 调用 checks.check_requirements 方法,确保安装了 hub-sdk ,并且版本不低于 0.0.12 。 这是为了确保代码能够与 Ultralytics HUB 交互。
checks.check_requirements("hub-sdk>=0.0.12")
# 使用 HUBTrainingSession.create_session(model) 创建一个 HUB 会话。 create_session 方法会根据输入的 model (HUB 模型的 URL 或标识符)从 HUB 获取模型文件和相关的训练参数。 返回的 session 对象包含了 模型文件路径 和 其他训练相关的配置 。
session = HUBTrainingSession.create_session(model)
# 从 session 中提取 模型文件路径 ,并将其赋值给 model 变量。 这一步确保后续代码可以使用本地路径来加载模型文件。
model = session.model_file
# 检查 session 是否包含从 HUB 发送的 训练参数 ( train_args )。
if session.train_args: # training sent from HUB
# 如果存在,则将 session 存储到 self.session 中,以便后续可以使用这些训练参数。 这允许模型在初始化时直接加载 HUB 提供的训练配置,而无需用户手动指定。
self.session = session
# Check if Triton Server model
# 检查是否为 Triton Server 模型。
# 调用 self.is_triton_model(model) 方法,判断输入的 model 是否为一个 Triton Server 模型。 is_triton_model 方法通常会检查 model 是否是一个有效的 Triton Server URL(例如以 http:// 或 grpc:// 开头)。 如果返回 True ,则表示 model 是一个 Triton Server 模型。
elif self.is_triton_model(model):
# 将 model 的值同时赋给 self.model_name 和 self.model 。 self.model_name 用于 存储模型的名称或路径 。 self.model 用于 后续加载模型时引用 。
self.model_name = self.model = model
# 设置 self.overrides["task"] 的值。 如果 task 参数被明确指定,则使用其值。 否则,默认设置为 "detect" (目标检测任务)。
self.overrides["task"] = task or "detect" # set `task=detect` if not explicitly set
# 如果 model 是一个 Triton Server 模型,则直接返回,不再执行后续代码。 这是因为 Triton Server 模型的加载方式与其他模型不同,后续代码不再适用。
return
# 这段代码的核心功能是根据输入的 model 参数来源(Ultralytics HUB 或 Triton Server)执行不同的处理逻辑。 Ultralytics HUB 模型:从 HUB 获取模型文件和训练参数。如果 HUB 提供了训练参数,则存储到 self.session 中,以便后续使用。Triton Server 模型:设置模型名称和任务类型。直接返回,跳过后续代码,因为 Triton Server 模型的加载方式与其他模型不同。这种设计允许 Model 类灵活支持多种模型来源,同时确保代码的逻辑清晰且易于扩展。
# 这段代码是 Model 类初始化方法的一部分,用于根据输入的 model 参数加载或创建一个新的 YOLO 模型实例。
# Load or create new YOLO model
# 检查 model 参数的文件扩展名是否为 .yaml 或 .yml 。 如果是,表示用户提供了一个模型配置文件,需要根据该配置文件创建一个新的 YOLO 模型实例。
if Path(model).suffix in {".yaml", ".yml"}:
# 调用私有方法 _new ,传入 model (配置文件路径)、 task (任务类型)和 verbose (是否启用详细输出)。 _new 方法的作用是根据配置文件加载模型架构和相关参数,并初始化模型对象。它通常会执行以下操作 :
# 加载 .yaml 文件中的模型配置。
# 根据配置文件创建模型实例。
# 设置任务类型(如检测、分割等)。
# 如果启用了详细输出,打印模型相关信息。
self._new(model, task=task, verbose=verbose)
# 如果 model 的扩展名不是 .yaml 或 .yml ,则假设它是一个预训练模型文件(如 .pt 文件)。
else:
# 调用私有方法 _load ,传入 model (模型文件路径)和 task (任务类型)。 _load 方法的作用是从预训练模型文件加载模型权重和配置,并初始化模型对象。它通常会执行以下操作 :
# 检查模型文件是否存在。
# 加载模型权重和训练参数。
# 设置任务类型。
# 如果需要,恢复训练状态(如从检查点恢复)。
self._load(model, task=task)
# Delete super().training for accessing self.model.training
# 删除从父类 nn.Module 继承的 self.training 属性。 这是因为 Model 类需要直接访问底层模型对象的 training 属性(即 self.model.training ),而不是父类的 training 属性。 删除后, self.training 不再指向父类的属性,从而避免了潜在的冲突。
del self.training
# 这段代码的核心功能是根据输入的 model 参数的类型(配置文件或预训练模型文件)选择合适的加载方式。如果 model 是 .yaml 或 .yml 文件:调用 _new 方法,根据配置文件创建一个新的 YOLO 模型实例。如果 model 是其他文件(如 .pt 文件):调用 _load 方法,从预训练模型文件加载模型权重和配置。最后,删除 self.training 属性,确保可以直接访问底层模型的 training 状态,而不会与父类的属性冲突。这种设计使得 Model 类能够灵活支持从配置文件创建模型和从预训练文件加载模型两种方式。
# 这段代码实现了 Model 类的初始化逻辑,支持从多种来源加载或创建 YOLO 模型,包括。本地文件:加载 .pt 模型文件或 .yaml 配置文件。 Ultralytics HUB:从 HUB 加载模型并获取训练参数。 Triton Server:支持加载 Triton Server 上的模型。初始化过程中,还设置了模型的回调函数、任务类型、训练器、预测器等属性,为后续的模型训练、预测和部署做好准备。
# 这段代码定义了 Model 类的 __call__ 方法,它是一个特殊方法,允许类的实例像函数一样被调用。这种设计使得模型实例可以直接用于执行预测操作,而无需显式调用 predict 方法。
# __call__ 方法是一个特殊方法,允许类的实例像函数一样被调用。例如,如果 model 是 Model 类的实例,那么可以直接通过 model() 调用该方法。
# 参数 :
# 1.source :输入源,可以是以下类型之一 : str 或 Path :文件路径或 URL。 int :摄像头设备编号。 Image.Image :PIL 图像对象。 list 或 tuple :图像列表或元组。 np.ndarray :NumPy 数组。 torch.Tensor :PyTorch 张量。
# 2.stream :布尔值,表示是否将输入源视为连续流(如视频流)。默认为 False 。
# 3.**kwargs :其他关键字参数,用于传递给预测方法的额外配置。
# 返回一个列表,包含预测结果。
def __call__(
self,
source: Union[str, Path, int, Image.Image, list, tuple, np.ndarray, torch.Tensor] = None,
stream: bool = False,
**kwargs: Any,
) -> list:
# 预测方法的别名,使模型实例可以调用进行预测。
# 此方法通过允许使用所需参数直接调用模型实例,简化了进行预测的过程。
"""
Alias for the predict method, enabling the model instance to be callable for predictions.
This method simplifies the process of making predictions by allowing the model instance to be called
directly with the required arguments.
Args:
source (str | Path | int | PIL.Image | np.ndarray | torch.Tensor | List | Tuple): The source of
the image(s) to make predictions on. Can be a file path, URL, PIL image, numpy array, PyTorch
tensor, or a list/tuple of these.
stream (bool): If True, treat the input source as a continuous stream for predictions.
**kwargs: Additional keyword arguments to configure the prediction process.
Returns:
(List[ultralytics.engine.results.Results]): A list of prediction results, each encapsulated in a
Results object.
Examples:
>>> model = YOLO("yolo11n.pt")
>>> results = model("https://ultralytics.com/images/bus.jpg")
>>> for r in results:
... print(f"Detected {len(r)} objects in image")
"""
# 调用了 self.predict 方法,并将 source 、 stream 和其他关键字参数传递给它。 self.predict 方法是实际执行预测的核心方法,它会根据输入的 source 类型和配置执行相应的预测逻辑,并返回预测结果。
return self.predict(source, stream, **kwargs)
# 这段代码的作用是为 Model 类实例提供一个便捷的接口,使其可以直接被调用以执行预测任务。通过定义 __call__ 方法,用户可以像调用函数一样调用模型实例,例如 :
# model = Model("yolo11n.pt")
# results = model("path/to/image.jpg")
# 这种设计使得模型的使用更加直观和灵活,同时隐藏了底层的预测逻辑实现细节。
# 这段代码定义了一个静态方法 is_triton_model ,用于判断输入的字符串 model 是否表示一个 Triton Server 模型。
# @staticmethod 是一个装饰器,表示该方法是一个静态方法。静态方法不需要实例化类即可直接调用,并且不会自动接收 self 参数。
@staticmethod
# 参数 :
# 1.model : str ,输入的字符串,表示模型的路径或标识符。
# 返回一个布尔值( bool ),表示输入的字符串是否符合 Triton Server 模型的格式。
def is_triton_model(model: str) -> bool:
# 检查给定的模型字符串是否为 Triton 服务器 URL。
# 此静态方法通过使用 urllib.parse.urlsplit() 解析其组件来确定提供的模型字符串是否代表有效的 Triton 服务器 URL。
"""
Checks if the given model string is a Triton Server URL.
This static method determines whether the provided model string represents a valid Triton Server URL by
parsing its components using urllib.parse.urlsplit().
Args:
model (str): The model string to be checked.
Returns:
(bool): True if the model string is a valid Triton Server URL, False otherwise.
Examples:
>>> Model.is_triton_model("http://localhost:8000/v2/models/yolo11n")
True
>>> Model.is_triton_model("yolo11n.pt")
False
"""
# 导入 urlsplit 函数,用于解析 URL。 urlsplit 可以将 URL 分解为以下组成部分 :
# scheme :协议(如 http 、 https 、 grpc )。
# netloc :网络位置(如域名或 IP 地址)。
# path :路径。
# query :查询字符串。
# fragment :片段标识符。
from urllib.parse import urlsplit
# 使用 urlsplit 函数解析输入的 model 字符串,将其分解为上述组成部分,并将结果存储在变量 url 中。
url = urlsplit(model)
# 根据解析结果判断 model 是否为有效的 Triton Server 模型。
# url.netloc :检查是否有网络位置(域名或 IP 地址)。如果为空,则不是有效的 URL。
# url.path :检查是否有路径。Triton Server 模型通常需要一个路径(如 /v2/models/model_name )。
# url.scheme :检查协议是否为 http 或 grpc 。Triton Server 支持这两种协议。
# 如果上述条件都满足,则返回 True ,表示输入的字符串是一个有效的 Triton Server 模型;否则返回 False 。
return url.netloc and url.path and url.scheme in {"http", "grpc"}
# 这个静态方法的核心功能是判断输入的字符串是否符合 Triton Server 模型的格式。它通过解析 URL 的各个组成部分来验证:是否包含有效的网络位置( netloc )。是否包含路径( path )。协议是否为 http 或 grpc 。例如:输入 "http://localhost:8000/v2/models/yolov8" 会返回 True 。输入 "yolov8.pt" 或 "https://example.com" 会返回 False 。这种实现方式简洁且高效,适用于在模型加载阶段快速判断输入是否为 Triton Server 模型。
# 这段代码定义了一个静态方法 is_hub_model ,用于判断输入的字符串 model 是否表示一个 Ultralytics HUB 模型。
# @staticmethod 是一个装饰器,表示该方法是一个静态方法。静态方法不需要实例化类即可直接调用,并且不会自动接收 self 参数。
@staticmethod
# 参数 :
# 1.model : str ,输入的字符串,表示模型的路径或标识符。
# 返回一个布尔值( bool ),表示输入的字符串是否符合 Ultralytics HUB 模型的格式。
def is_hub_model(model: str) -> bool:
# 检查提供的模型是否为 Ultralytics HUB 模型。
# 此静态方法确定给定的模型字符串是否代表有效的 Ultralytics HUB 模型标识符。
"""
Check if the provided model is an Ultralytics HUB model.
This static method determines whether the given model string represents a valid Ultralytics HUB model
identifier.
Args:
model (str): The model string to check.
Returns:
(bool): True if the model is a valid Ultralytics HUB model, False otherwise.
Examples:
>>> Model.is_hub_model("https://hub.ultralytics.com/models/MODEL")
True
>>> Model.is_hub_model("yolo11n.pt")
False
"""
# 如果输入的 model 字符串以 HUB_WEB_ROOT/models/ 开头,则返回 True ,表示它是一个有效的 Ultralytics HUB 模型。 否则,返回 False 。
# model.startswith(...) :检查输入的字符串 model 是否以特定的前缀开头。
# HUB_WEB_ROOT :这是一个变量,表示 Ultralytics HUB 的根路径。它通常是一个固定的 URL,例如 https://hub.ultralytics.com 。
# f"{HUB_WEB_ROOT}/models/" :通过格式化字符串,构造出 HUB 模型的路径前缀。例如,如果 HUB_WEB_ROOT 是 https://hub.ultralytics.com ,那么完整的前缀将是 `https://hub.ultralytics.com/models/`。
return model.startswith(f"{HUB_WEB_ROOT}/models/")
# 这个静态方法的核心功能是通过检查输入字符串的前缀来判断是否为一个 Ultralytics HUB 模型。具体逻辑如下。构造前缀:根据 HUB_WEB_ROOT 和固定的路径 /models/ 构造出 HUB 模型的路径前缀。检查前缀:使用 startswith 方法检查输入的 model 是否以该前缀开头。返回结果:如果匹配,则返回 True ;否则返回 False 。示例 :假设 HUB_WEB_ROOT 的值为 https://hub.ultralytics.com ,那么:输入 "https://hub.ultralytics.com/models/yolov8n.pt" 会返回 True 。输入 "yolov8n.pt" 或 "https://example.com/models/yolov8n.pt" 会返回 False 。这种实现方式简单且高效,适用于在模型加载阶段快速判断输入是否为 Ultralytics HUB 模型。
# 这段代码定义了 Model 类的私有方法 _new ,用于根据一个 YAML 配置文件创建一个新的 YOLO 模型实例。
# 参数 :
# 1.cfg : str ,YAML 配置文件的路径。
# 2.task : str (可选) ,模型的任务类型(如 detect 、 segment 等)。如果未指定,则会自动推断。
# 3.model (可选) :一个自定义的模型类。如果未提供,则使用默认的模型类。
# 4.verbose : bool ,是否启用详细输出。默认为 False 。3
# 返回值 : None ,因为该方法的主要目的是初始化模型实例,而不是返回值。
def _new(self, cfg: str, task=None, model=None, verbose=False) -> None:
# 初始化新模型并从模型定义中推断任务类型。
# 此方法根据提供的配置文件创建一个新的模型实例。它加载模型配置,如果未指定则推断任务类型,并使用任务图中的适当类初始化模型。
# 引发:
# ValueError:如果配置文件无效或无法推断任务。
# ImportError:如果未安装指定任务所需的依赖项。
"""
Initializes a new model and infers the task type from the model definitions.
This method creates a new model instance based on the provided configuration file. It loads the model
configuration, infers the task type if not specified, and initializes the model using the appropriate
class from the task map.
Args:
cfg (str): Path to the model configuration file in YAML format.
task (str | None): The specific task for the model. If None, it will be inferred from the config.
model (torch.nn.Module | None): A custom model instance. If provided, it will be used instead of creating
a new one.
verbose (bool): If True, displays model information during loading.
Raises:
ValueError: If the configuration file is invalid or the task cannot be inferred.
ImportError: If the required dependencies for the specified task are not installed.
Examples:
>>> model = Model()
>>> model._new("yolo11n.yaml", task="detect", verbose=True)
"""
# 使用 yaml_model_load 函数加载 YAML 配置文件,并将其内容解析为一个字典 cfg_dict 。 这个字典通常包含了模型的架构、训练参数等信息。
# def yaml_model_load(path): -> 用于从 YAML 文件加载 YOLOv8 模型的配置信息。返回 包含模型配置信息的字典 d 。 -> return d
cfg_dict = yaml_model_load(cfg)
# 将输入的 YAML 配置文件路径存储到 self.cfg 属性中,以便后续使用。
self.cfg = cfg
# 如果用户指定了 task 参数,则直接使用它。 如果未指定,则调用 guess_model_task 函数,根据配置文件的内容推断任务类型(如检测、分割等)。
# def guess_model_task(model):
# -> 用于猜测模型的任务类型(如分类、检测、分割等)。该函数通过检查模型的配置、结构或文件名来推断模型的任务类型。如果无法从模型文件名中推断任务类型,假设任务类型为 "detect" 。
# -> return "segment" / return "classify" / return "pose" / return "obb" / return "detect"
self.task = task or guess_model_task(cfg_dict)
# 如果用户提供了 自定义的模型类 ( model 参数),则使用它;否则调用 _smart_load("model") 方法动态加载默认的模型类。 使用加载的模型类和配置字典 cfg_dict 创建模型实例。 如果启用了详细输出( verbose 为 True 且当前进程的 RANK 为 -1 ),则在模型初始化过程中打印详细信息。
self.model = (model or self._smart_load("model"))(cfg_dict, verbose=verbose and RANK == -1) # build model
# 将模型的 配置文件路径 和 任务类型 存储到 self.overrides 字典中,以便后续覆盖默认行为或记录模型状态。
self.overrides["model"] = self.cfg
self.overrides["task"] = self.task
# Below added to allow export from YAMLs
# 将默认的 模型参数 ( DEFAULT_CFG_DICT )与 用户指定的覆盖参数 ( self.overrides )合并。 如果用户指定了某些参数,则优先使用用户指定的参数。 将 合并后的参数 存储到模型的 args 属性中,以便后续使用。
self.model.args = {**DEFAULT_CFG_DICT, **self.overrides} # combine default and model args (prefer model args)
# 将 任务类型 ( self.task )存储到模型的 task 属性中,确保模型实例知道其任务类型。
self.model.task = self.task
# 将 YAML 配置文件的路径 存储到 self.model_name 属性中,便于后续引用。
self.model_name = cfg
# 这个方法的核心功能是根据 YAML 配置文件创建一个新的 YOLO 模型实例。它执行了以下步骤。加载 YAML 配置文件并解析为字典。确定任务类型(用户指定或自动推断)。动态加载模型类并创建模型实例。合并默认参数和用户指定的覆盖参数。设置模型的属性(如任务类型、模型名称等)。这种设计允许用户通过配置文件灵活地定义模型的架构和行为,同时支持自定义模型类和覆盖默认参数。
# 这段代码定义了 Model 类的私有方法 _load ,用于从权重文件加载一个预训练的 YOLO 模型。
# 参数 :
# 1.weights : str ,权重文件的路径或 URL。
# 2.task : str (可选) ,模型的任务类型(如 detect 、 segment 等)。如果未指定,则会自动推断。
# 返回值 : None ,因为该方法的主要目的是初始化模型实例,而不是返回值。
def _load(self, weights: str, task=None) -> None:
# 从检查点文件加载模型或从权重文件初始化模型。
# 此方法处理从 .pt 检查点文件或其他权重文件格式加载模型。它根据加载的权重设置模型、任务和相关属性。
# 引发:
# FileNotFoundError:如果指定的权重文件不存在或无法访问。
# ValueError:如果权重文件格式不受支持或无效。
"""
Loads a model from a checkpoint file or initializes it from a weights file.
This method handles loading models from either .pt checkpoint files or other weight file formats. It sets
up the model, task, and related attributes based on the loaded weights.
Args:
weights (str): Path to the model weights file to be loaded.
task (str | None): The task associated with the model. If None, it will be inferred from the model.
Raises:
FileNotFoundError: If the specified weights file does not exist or is inaccessible.
ValueError: If the weights file format is unsupported or invalid.
Examples:
>>> model = Model()
>>> model._load("yolo11n.pt")
>>> model._load("path/to/weights.pth", task="detect")
"""
# 检查 weights 是否以常见的网络协议(如 http 、 https 、 rtsp 等)开头。 如果是,调用 checks.check_file 方法。 下载权重文件到本地目录( SETTINGS["weights_dir"] )。 返回下载后的本地文件路径。
if weights.lower().startswith(("https://", "http://", "rtsp://", "rtmp://", "tcp://")):
# def check_file(file, suffix="", download=True, download_dir=".", hard=True):
# -> 用于检查文件是否存在,如果不存在则尝试下载文件,并返回文件的路径。直接返回 file 。返回下载后的文件路径,将其转换为字符串形式。如果找到文件,返回第一个匹配的文件路径。 如果未找到文件,返回空列表 [] 。
# -> return file / return str(file) / return files[0] if len(files) else [] # return file
weights = checks.check_file(weights, download_dir=SETTINGS["weights_dir"]) # download and return local file
# 调用 checks.check_model_file_from_stem 方法。 如果输入的 weights 是一个不带扩展名的模型名称(如 yolo11n ),会自动添加 .pt 扩展名。 确保权重文件的路径是完整的。
# def check_model_file_from_stem(model="yolo11n"):
# -> 用于根据模型的名称(或“stem”)返回完整的模型文件名。如果条件满足,则使用 Path(model).with_suffix(".pt") 为模型名称添加 .pt 扩展名,并返回完整的文件路径。例如,输入 "yolo11n" 会返回 Path("yolo11n.pt") 。如果条件不满足,则直接返回原始的 model 输入。
# -> return Path(model).with_suffix(".pt") / return model
weights = checks.check_model_file_from_stem(weights) # add suffix, i.e. yolo11n -> yolo11n.pt
# 检查权重文件的扩展名是否为 .pt 。
if Path(weights).suffix == ".pt":
# 如果是,调用 attempt_load_one_weight 方法加载权重文件。 返回 模型对象和检查点信息 ( ckpt )。
# def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False): -> 用于加载单个模型权重文件,并对模型进行一系列的初始化和兼容性处理。返回 处理后的模型 和 检查点数据 。 -> return model, ckpt
self.model, self.ckpt = attempt_load_one_weight(weights)
# 从模型的 args 属性中提取 任务类型 ( task ),并存储到 self.task 中。
self.task = self.model.args["task"]
# 调用 _reset_ckpt_args 方法,过滤检查点中的参数,保留关键参数(如 imgsz 、 data 、 task 等)。 将过滤后的参数存储到 self.overrides 和 self.model.args 中。
self.overrides = self.model.args = self._reset_ckpt_args(self.model.args)
# 将 检查点文件的路径存 储到 self.ckpt_path 中。
self.ckpt_path = self.model.pt_path
# 如果权重文件不是 .pt 文件。
else:
# 调用 checks.check_file 方法,确保文件路径有效。
weights = checks.check_file(weights) # runs in all cases, not redundant with above call
# 将 权重文件路径 存储到 self.model 中, self.ckpt 设置为 None 。
self.model, self.ckpt = weights, None
# 如果未指定任务类型( task ),调用 guess_model_task 方法根据权重文件推断任务类型。
# def guess_model_task(model):
# -> 用于猜测模型的任务类型(如分类、检测、分割等)。该函数通过检查模型的配置、结构或文件名来推断模型的任务类型。如果无法从模型文件名中推断任务类型,假设任务类型为 "detect" 。
# -> return "segment" / return "classify" / return "pose" / return "obb" / return "detect"
self.task = task or guess_model_task(weights)
# 将 权重文件路径 存储到 self.ckpt_path 中。
self.ckpt_path = weights
# 将 权重文件路径 和 任务类型 存储到 self.overrides 中。
self.overrides["model"] = weights
self.overrides["task"] = self.task
# 将 权重文件路径 存储到 self.model_name 中,便于后续引用。
self.model_name = weights
# 这个方法的核心功能是从权重文件加载一个预训练的 YOLO 模型,并初始化模型实例。它执行了以下步骤。检查权重文件是否为 URL,并下载到本地(如果需要)。确保权重文件路径完整,并自动添加 .pt 扩展名(如果需要)。加载 .pt 文件:提取模型对象和检查点信息。设置任务类型和覆盖参数。处理非 .pt 文件:确保文件路径有效。推断任务类型(如果未指定)。更新模型的覆盖参数和名称。这种设计使得 _load 方法能够灵活处理不同类型的权重文件(本地文件或网络 URL),并确保模型实例正确初始化。
# 这段代码定义了 Model 类的私有方法 _check_is_pytorch_model ,用于检查当前模型是否为 PyTorch 模型。如果模型不是 PyTorch 模型,该方法会抛出一个 TypeError ,并提供详细的错误信息。
# _check_is_pytorch_model 表示这是一个私有方法(以单下划线开头)。
# 返回值 : None ,因为该方法的主要目的是执行检查,而不是返回值。
def _check_is_pytorch_model(self) -> None:
# 检查模型是否为 PyTorch 模型,如果不是,则引发 TypeError。
# 此方法验证模型是 PyTorch 模块还是 .pt 文件。它用于确保某些需要 PyTorch 模型的操作仅在兼容的模型类型上执行。
# 引发:
# TypeError:如果模型不是 PyTorch 模块或 .pt 文件。错误消息提供有关支持的模型格式和操作的详细信息。
"""
Checks if the model is a PyTorch model and raises a TypeError if it's not.
This method verifies that the model is either a PyTorch module or a .pt file. It's used to ensure that
certain operations that require a PyTorch model are only performed on compatible model types.
Raises:
TypeError: If the model is not a PyTorch module or a .pt file. The error message provides detailed
information about supported model formats and operations.
Examples:
>>> model = Model("yolo11n.pt")
>>> model._check_is_pytorch_model() # No error raised
>>> model = Model("yolo11n.onnx")
>>> model._check_is_pytorch_model() # Raises TypeError
"""
# 检查 self.model 是否为 字符串 或 Path 对象,并且其扩展名是否为 .pt 。 如果满足条件, pt_str 为 True ,表示模型是一个以 .pt 结尾的文件路径。
pt_str = isinstance(self.model, (str, Path)) and Path(self.model).suffix == ".pt"
# 检查 self.model 是否为 torch.nn.Module 的实例。 如果是, pt_module 为 True ,表示模型是一个 PyTorch 模型对象。
pt_module = isinstance(self.model, nn.Module)
# 如果 self.model 既不是 .pt 文件路径,也不是 nn.Module 实例,则执行以下操作。
if not (pt_module or pt_str):
# 抛出一个 TypeError ,提示用户模型必须是 .pt 格式的 PyTorch 模型才能执行当前方法。 错误信息中还说明了 :
# PyTorch 模型支持训练、验证、预测和导出等操作。
# 导出格式(如 ONNX、TensorRT)仅支持预测和验证模式。
# 如果需要在 CUDA 或 MPS 设备上运行推理,需要显式指定设备参数(如 device=0 )。
raise TypeError(
f"model='{self.model}' should be a *.pt PyTorch model to run this method, but is a different format. " # model='{self.model}' 应为 *.pt PyTorch 模型以运行此方法,但格式不同。
f"PyTorch models can train, val, predict and export, i.e. 'model.train(data=...)', but exported " # PyTorch 模型可以训练、验证、预测和导出,即“model.train(data=...)”,
f"formats like ONNX, TensorRT etc. only support 'predict' and 'val' modes, " # 但导出的格式(如 ONNX、TensorRT 等)仅支持“预测”和“验证”模式,
f"i.e. 'yolo predict model=yolo11n.onnx'.\nTo run CUDA or MPS inference please pass the device " # 即“yolo predict model=yolo11n.onnx”。\n要运行 CUDA 或 MPS 推理,
f"argument directly in your inference command, i.e. 'model.predict(source=..., device=0)'" # 请在推理命令中直接传递设备参数,即“model.predict(source=..., device=0)”。
)
# 这个方法的核心功能是检查当前模型是否为 PyTorch 模型。它通过以下逻辑实现。检查 self.model 是否为 .pt 文件路径。检查 self.model 是否为 nn.Module 实例。如果两者都不满足,抛出 TypeError 并提供详细的错误信息。示例场景。正常情况:如果 self.model 是一个 .pt 文件路径或 nn.Module 实例,该方法不会抛出异常。异常情况:如果 self.model 是其他格式(如 ONNX 文件),该方法会抛出错误,提示用户模型格式不支持当前操作。用途:这个方法通常在需要执行 PyTorch 模型特定操作(如训练、导出等)时被调用,以确保模型格式正确。
# 这段代码定义了 Model 类的 reset_weights 方法,用于将模型的权重重置为初始状态。
# 定义了一个名为 reset_weights 的方法,该方法属于 Model 类。
# 返回值为 "Model" ,表示该方法返回 Model 类的实例(即当前模型实例)。
def reset_weights(self) -> "Model":
# 将模型的权重重置为初始状态。
# 此方法遍历模型中的所有模块,如果它们具有“reset_parameters”方法,则重置其参数。它还确保所有参数的“requires_grad”都设置为 True,从而可以在训练期间更新它们。
# 引发:
# AssertionError:如果模型不是 PyTorch 模型。
"""
Resets the model's weights to their initial state.
This method iterates through all modules in the model and resets their parameters if they have a
'reset_parameters' method. It also ensures that all parameters have 'requires_grad' set to True,
enabling them to be updated during training.
Returns:
(Model): The instance of the class with reset weights.
Raises:
AssertionError: If the model is not a PyTorch model.
Examples:
>>> model = Model("yolo11n.pt")
>>> model.reset_weights()
"""
# 调用 _check_is_pytorch_model 方法,检查当前模型是否为 PyTorch 模型。 如果模型不是 PyTorch 模型,会抛出 TypeError ,确保只有 PyTorch 模型可以执行此方法。
self._check_is_pytorch_model()
# 遍历 模型的所有模块 ( modules )。 self.model.modules() 返回一个生成器,包含模型中所有子模块(如层、子网络等)。
for m in self.model.modules():
# 对于每个模块 m ,检查它是否具有 reset_parameters 方法。 如果存在,调用该方法以 重置模块的参数 。 reset_parameters 是 PyTorch 中许多层(如 nn.Conv2d 、 nn.Linear 等)的标准方法,用于将权重重置为初始值。
if hasattr(m, "reset_parameters"):
m.reset_parameters()
# 遍历 模型的所有参数 ( parameters )。
for p in self.model.parameters():
# 将每个参数的 requires_grad 属性设置为 True ,确保这些参数在后续训练中可以更新。 这一步是为了确保模型的所有参数都可以参与梯度计算和优化。
p.requires_grad = True
# 返回当前模型实例( self ),允许链式调用(如 model.reset_weights().train() )。
return self
# 这个方法的核心功能是将模型的权重重置为初始状态,同时确保所有参数可以参与梯度计算。它通过以下步骤实现。检查模型是否为 PyTorch 模型:通过 _check_is_pytorch_model 方法确保模型格式正确。重置模块参数:遍历模型的所有模块,调用 reset_parameters 方法重置权重。启用梯度计算:将所有参数的 requires_grad 属性设置为 True 。返回模型实例:允许链式调用。这种方法适用于在训练开始前或重新训练模型时,将模型权重恢复到初始状态,从而确保训练过程的一致性。
# 这段代码定义了 Model 类的 load 方法,用于从指定的权重文件加载模型参数。
# 定义了一个名为 load 的方法,属于 Model 类。 参数 :
# 1.weights :权重文件的路径,可以是字符串或 Path 对象,默认值为 "yolo11n.pt" 。
# 返回值为 "Model" ,表示该方法返回 Model 类的实例(即当前模型实例)。
def load(self, weights: Union[str, Path] = "yolo11n.pt") -> "Model":
# 将指定权重文件中的参数加载到模型中。
# 此方法支持从文件或直接从权重对象加载权重。它按名称和形状匹配参数并将它们传输到模型。
# 引发:
# AssertionError:如果模型不是 PyTorch 模型。
"""
Loads parameters from the specified weights file into the model.
This method supports loading weights from a file or directly from a weights object. It matches parameters by
name and shape and transfers them to the model.
Args:
weights (Union[str, Path]): Path to the weights file or a weights object.
Returns:
(Model): The instance of the class with loaded weights.
Raises:
AssertionError: If the model is not a PyTorch model.
Examples:
>>> model = Model()
>>> model.load("yolo11n.pt")
>>> model.load(Path("path/to/weights.pt"))
"""
# 调用 _check_is_pytorch_model 方法,确保当前模型是一个 PyTorch 模型。 如果模型不是 PyTorch 模型,会抛出 TypeError ,防止对非 PyTorch 模型执行加载操作。
self._check_is_pytorch_model()
# 检查 weights 是否为 字符串 或 Path 对象。
if isinstance(weights, (str, Path)):
# 如果是,将 权重文件路径 存储到 self.overrides["pretrained"] 中。这一步是为了在分布式数据并行(DDP)训练中记录预训练权重路径。
self.overrides["pretrained"] = weights # remember the weights for DDP training
# 调用 attempt_load_one_weight(weights) 方法。 加载权重文件,并返回 实际的权重对象 和 检查点信息 ( ckpt )。 将返回的权重对象存储到 weights 变量中,检查点信息存储到 self.ckpt 中。
# def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False): -> 用于加载单个模型权重文件,并对模型进行一系列的初始化和兼容性处理。返回 处理后的模型 和 检查点数据 。 -> return model, ckpt
weights, self.ckpt = attempt_load_one_weight(weights)
# 调用模型的 load 方法,将加载的权重应用到模型中。
self.model.load(weights)
# 返回当前模型实例( self ),允许链式调用(例如 : model.load("yolo11n.pt").train() )。
return self
# 这个方法的核心功能是从指定的权重文件加载模型参数,并将其应用到当前模型中。它通过以下步骤实现。检查模型是否为 PyTorch 模型:确保只有 PyTorch 模型可以执行加载操作。处理权重文件路径:如果权重是文件路径,记录路径并加载权重文件。加载权重:将权重应用到模型中。返回模型实例:允许链式调用。这种方法使得用户可以方便地从本地文件或预训练模型加载权重,同时确保模型状态被正确更新。
# 这段代码定义了 Model 类的 save 方法,用于将当前模型的状态保存到指定的文件中。
# 定义了一个名为 save 的方法,属于 Model 类。 参数 :
# 1.filename :保存模型的文件路径,可以是字符串或 Path 对象,默认值为 "saved_model.pt" 。
# 返回值为 None ,因为该方法的主要目的是将模型保存到文件中,而不是返回值。
def save(self, filename: Union[str, Path] = "saved_model.pt") -> None:
# 将当前模型状态保存到文件。
# 此方法将模型的检查点 (ckpt) 导出到指定的文件名。它包括元数据,例如日期、Ultralytics 版本、许可证信息和文档链接。
"""
Saves the current model state to a file.
This method exports the model's checkpoint (ckpt) to the specified filename. It includes metadata such as
the date, Ultralytics version, license information, and a link to the documentation.
Args:
filename (Union[str, Path]): The name of the file to save the model to.
Raises:
AssertionError: If the model is not a PyTorch model.
Examples:
>>> model = Model("yolo11n.pt")
>>> model.save("my_model.pt")
"""
# 调用 _check_is_pytorch_model 方法,确保当前模型是一个 PyTorch 模型。 如果模型不是 PyTorch 模型,会抛出 TypeError ,阻止非 PyTorch 模型执行保存操作。
self._check_is_pytorch_model()
# 导入必要的模块。
# deepcopy 用于创建模型的深拷贝,避免修改原始模型。
from copy import deepcopy
# datetime 用于获取当前时间。
from datetime import datetime
# __version__ 从 ultralytics 包中获取当前版本号。
from ultralytics import __version__
# 创建一个字典 updates ,包含需要保存的额外信息。
updates = {
# 如果 self.model 是一个 nn.Module 实例,使用 deepcopy 创建模型的深拷贝,并将其转换为半精度( half() )。这有助于减小保存文件的大小。 如果 self.model 不是 nn.Module ,直接使用原始模型。
"model": deepcopy(self.model).half() if isinstance(self.model, nn.Module) else self.model,
# 保存当前时间的 ISO 格式字符串。
"date": datetime.now().isoformat(),
# 保存当前 ultralytics 包的版本号。
"version": __version__,
# 保存许可证信息。
"license": "AGPL-3.0 License (https://ultralytics.com/license)",
# 保存文档链接。
"docs": "https://docs.ultralytics.com",
}
# 使用 torch.save 将模型的状态和额外信息保存到指定的文件中。 合并 self.ckpt ( 已有的检查点信息 )和 updates ( 新添加的信息 )。 保存到文件 filename 中。
torch.save({**self.ckpt, **updates}, filename)
# 这个方法的核心功能是将当前模型的状态保存到指定的文件中。它通过以下步骤实现。检查模型是否为 PyTorch 模型:确保只有 PyTorch 模型可以执行保存操作。创建深拷贝并转换为半精度:如果模型是 nn.Module ,创建深拷贝并转换为半精度,以减小保存文件的大小。添加额外信息:保存当前时间、版本号、许可证信息和文档链接。保存到文件:使用 torch.save 将模型状态和额外信息保存到指定文件中。这种方法使得用户可以方便地将模型的状态保存到文件中,便于后续加载和使用。
# 这段代码定义了 Model 类的 info 方法,用于获取并返回模型的相关信息,支持详细模式和简略模式。
# 定义了一个名为 info 的方法,属于 Model 类。 参数 :
# 1.detailed : bool ,是否返回详细信息,默认为 False 。如果设置为 True ,则返回更多关于模型的详细信息,如每一层的参数和结构。
# 2.verbose : bool ,是否直接打印信息,默认为 True 。如果设置为 True ,则直接打印模型信息;如果设置为 False ,则返回信息而不打印。
def info(self, detailed: bool = False, verbose: bool = True):
# 记录或返回模型信息。
# 此方法根据传递的参数提供有关模型的概述或详细信息。它可以控制输出的详细程度并以列表形式返回信息。
# 引发:
# TypeError:如果模型不是 PyTorch 模型。
"""
Logs or returns model information.
This method provides an overview or detailed information about the model, depending on the arguments
passed. It can control the verbosity of the output and return the information as a list.
Args:
detailed (bool): If True, shows detailed information about the model layers and parameters.
verbose (bool): If True, prints the information. If False, returns the information as a list.
Returns:
(List[str]): A list of strings containing various types of information about the model, including
model summary, layer details, and parameter counts. Empty if verbose is True.
Raises:
TypeError: If the model is not a PyTorch model.
Examples:
>>> model = Model("yolo11n.pt")
>>> model.info() # Prints model summary
>>> info_list = model.info(detailed=True, verbose=False) # Returns detailed info as a list
"""
# 调用 _check_is_pytorch_model 方法,确保当前模型是一个 PyTorch 模型。 如果模型不是 PyTorch 模型,会抛出 TypeError ,阻止非 PyTorch 模型执行此方法。 这一步确保了只有 PyTorch 模型可以调用 info 方法。
self._check_is_pytorch_model()
# 调用 底层模型对象 ( self.model )的 info 方法,并将参数 detailed 和 verbose 传递给它。 self.model.info 方法的具体实现取决于底层模型类,通常会根据 detailed 参数返回不同级别的模型信息。
# 如果 detailed=False ,返回模型的基本信息,如模型名称、任务类型等。
# 如果 detailed=True ,返回更详细的模型信息,如每一层的参数、形状、总参数数量等。
# 根据 verbose 参数的值, info 方法可能会直接打印信息,或者返回信息供后续处理。
return self.model.info(detailed=detailed, verbose=verbose)
# 这个方法的核心功能是获取并返回模型的相关信息,支持两种模式。简略模式( detailed=False ):返回模型的基本信息。详细模式( detailed=True ):返回模型的详细信息,包括每一层的参数和结构。此外, info 方法还支持两种输出方式:直接打印( verbose=True ):直接将模型信息打印到控制台。返回信息( verbose=False ):返回模型信息供后续处理或存储。这种方法使得用户可以灵活地获取模型的信息,无论是用于调试、记录还是展示模型结构。
# 这段代码定义了 Model 类的 fuse 方法,用于对 PyTorch 模型进行层融合操作。
# 定义了一个名为 fuse 的方法,属于 Model 类。 该方法没有参数,也没有返回值,其主要目的是对模型进行优化操作。
def fuse(self):
# 融合模型中的 Conv2d 和 BatchNorm2d 层以优化推理。
# 此方法迭代模型的模块并将连续的 Conv2d 和 BatchNorm2d 层融合为单个层。这种融合可以通过减少前向传递期间所需的操作和内存访问次数来显著提高推理速度。
# 融合过程通常涉及将 BatchNorm2d 参数(均值、方差、权重和偏差)折叠到前面的 Conv2d 层的权重和偏差中。这会产生一个 Conv2d 层,该层在一个步骤中同时执行卷积和规范化。
# 引发:
# TypeError:如果模型不是 PyTorch nn.Module。
"""
Fuses Conv2d and BatchNorm2d layers in the model for optimized inference.
This method iterates through the model's modules and fuses consecutive Conv2d and BatchNorm2d layers
into a single layer. This fusion can significantly improve inference speed by reducing the number of
operations and memory accesses required during forward passes.
The fusion process typically involves folding the BatchNorm2d parameters (mean, variance, weight, and
bias) into the preceding Conv2d layer's weights and biases. This results in a single Conv2d layer that
performs both convolution and normalization in one step.
Raises:
TypeError: If the model is not a PyTorch nn.Module.
Examples:
>>> model = Model("yolo11n.pt")
>>> model.fuse()
>>> # Model is now fused and ready for optimized inference
"""
# 调用 _check_is_pytorch_model 方法,确保当前模型是一个 PyTorch 模型。 如果模型不是 PyTorch 模型,会抛出 TypeError ,阻止非 PyTorch 模型执行此方法。 这一步确保了只有 PyTorch 模型可以进行层融合操作。
self._check_is_pytorch_model()
# 调用 底层模型对象 ( self.model )的 fuse 方法,执行层融合操作。 层融合是一种常见的模型优化技术,通常用于将多个连续的层(如 Conv2d 和 BatchNorm2d )合并为一个层。 这种优化可以减少模型的计算量和内存占用,从而提高推理速度。 具体实现细节取决于底层模型类的 fuse 方法,通常会遍历模型的所有模块,查找可以融合的层并进行优化。
self.model.fuse()
# 这个方法的核心功能是对 PyTorch 模型进行层融合操作,以优化模型的推理性能。它通过以下步骤实现。检查模型是否为 PyTorch 模型:确保只有 PyTorch 模型可以执行层融合操作。调用底层模型的 fuse 方法:执行具体的层融合逻辑。使用场景。模型优化:在将模型部署到生产环境之前,通常会进行层融合操作以提高推理速度。减少计算量:通过合并多个层,减少模型的计算量和内存占用。注意事项。 fuse 方法的具体实现取决于底层模型类,因此需要确保底层模型支持该操作。层融合通常只适用于 PyTorch 模型,因此在调用此方法之前,必须确保模型是 PyTorch 模型。
# 这段代码定义了 Model 类的 embed 方法,用于从输入数据中生成嵌入向量(embeddings)。
# 定义了一个名为 embed 的方法,属于 Model 类。 参数 :
# 1.source :输入源,可以是文件路径、URL、摄像头编号、图像列表、NumPy 数组或 PyTorch 张量。
# 2.stream :布尔值,表示是否将输入源视为连续流(如视频流)。默认为 False 。
# 3.**kwargs :其他关键字参数,用于传递给预测方法的额外配置。
# 返回一个列表,包含生成的图像嵌入。
def embed(
self,
source: Union[str, Path, int, list, tuple, np.ndarray, torch.Tensor] = None,
stream: bool = False,
**kwargs: Any,
) -> list:
# 根据提供的源生成图像嵌入。
# 此方法是 'predict()' 方法的包装器,专注于从图像源生成嵌入。它允许通过各种关键字参数自定义嵌入过程。
# 引发:
# AssertionError:如果模型不是 PyTorch 模型。
"""
Generates image embeddings based on the provided source.
This method is a wrapper around the 'predict()' method, focusing on generating embeddings from an image
source. It allows customization of the embedding process through various keyword arguments.
Args:
source (str | Path | int | List | Tuple | np.ndarray | torch.Tensor): The source of the image for
generating embeddings. Can be a file path, URL, PIL image, numpy array, etc.
stream (bool): If True, predictions are streamed.
**kwargs: Additional keyword arguments for configuring the embedding process.
Returns:
(List[torch.Tensor]): A list containing the image embeddings.
Raises:
AssertionError: If the model is not a PyTorch model.
Examples:
>>> model = YOLO("yolo11n.pt")
>>> image = "https://ultralytics.com/images/bus.jpg"
>>> embeddings = model.embed(image)
>>> print(embeddings[0].shape)
"""
# 检查 kwargs 是否包含键 "embed" 。如果未指定,则默认设置为 [len(self.model.model) - 2] 。
# 这里的逻辑是 :如果没有指定嵌入层的索引,则默认使用模型的倒数第二层作为嵌入层。
# len(self.model.model) - 2 表示模型中倒数第二层的索引。
if not kwargs.get("embed"):
kwargs["embed"] = [len(self.model.model) - 2] # embed second-to-last layer if no indices passed
# 调用 self.predict 方法,将输入源和关键字参数传递给它。 self.predict 方法是实际执行预测的核心方法,它会根据输入的 source 类型和配置执行相应的预测逻辑。 在这里, kwargs 中的 "embed" 参数会指示 predict 方法生成嵌入,而不是执行常规的预测任务。
# 最终返回一个列表,包含生成的图像嵌入。
return self.predict(source, stream, **kwargs)
# 这个方法的核心功能是从输入源生成图像嵌入(embeddings)。它通过以下步骤实现。检查嵌入层索引:如果用户未指定嵌入层的索引,则默认使用模型的倒数第二层。调用预测方法:将输入源和配置参数传递给 self.predict 方法,生成嵌入。返回嵌入结果:返回一个列表,包含生成的嵌入。使用场景。图像嵌入生成:在需要将图像转换为嵌入向量(如特征提取、相似性搜索等任务)时使用。灵活的嵌入层选择:用户可以通过指定嵌入层的索引来选择不同的层作为嵌入输出。注意事项。默认情况下,嵌入层为模型的倒数第二层。如果需要其他层的嵌入,可以通过 kwargs 显式指定。 self.predict 方法需要支持嵌入生成逻辑,否则可能会抛出错误。
# 这段代码定义了 Model 类的 predict 方法,用于执行模型的预测操作。它支持多种输入源类型,并允许通过关键字参数进行灵活配置。
# 定义了一个名为 predict 的方法,属于 Model 类。 参数 :
# 1.source :输入源,可以是文件路径、URL、摄像头编号、PIL 图像、图像列表、NumPy 数组或 PyTorch 张量。
# 2.stream :布尔值,表示是否将输入源视为连续流(如视频流)。默认为 False 。
# 3.predictor :自定义预测器对象。如果未提供,则使用默认预测器。
# 4.**kwargs :其他关键字参数,用于配置预测过程。
# 返回一个 List[Results] ,包含预测结果。
def predict(
self,
source: Union[str, Path, int, Image.Image, list, tuple, np.ndarray, torch.Tensor] = None,
stream: bool = False,
predictor=None,
**kwargs: Any,
) -> List[Results]:
# 使用 YOLO 模型对给定的图像源执行预测。
# 此方法简化了预测过程,允许通过关键字参数进行各种配置。它支持使用自定义预测器或默认预测器方法进行预测。该方法处理不同类型的图像源,并可以在流模式下运行。
# 注意事项:
# - 如果未提供“source”,则默认为 ASSETS 常量并发出警告。
# - 如果尚未存在,该方法将设置一个新的预测器,并在每次调用时更新其参数。
# - 对于 SAM 类型模型,“prompts”可以作为关键字参数传递。
"""
Performs predictions on the given image source using the YOLO model.
This method facilitates the prediction process, allowing various configurations through keyword arguments.
It supports predictions with custom predictors or the default predictor method. The method handles different
types of image sources and can operate in a streaming mode.
Args:
source (str | Path | int | PIL.Image | np.ndarray | torch.Tensor | List | Tuple): The source
of the image(s) to make predictions on. Accepts various types including file paths, URLs, PIL
images, numpy arrays, and torch tensors.
stream (bool): If True, treats the input source as a continuous stream for predictions.
predictor (BasePredictor | None): An instance of a custom predictor class for making predictions.
If None, the method uses a default predictor.
**kwargs: Additional keyword arguments for configuring the prediction process.
Returns:
(List[ultralytics.engine.results.Results]): A list of prediction results, each encapsulated in a
Results object.
Examples:
>>> model = YOLO("yolo11n.pt")
>>> results = model.predict(source="path/to/image.jpg", conf=0.25)
>>> for r in results:
... print(r.boxes.data) # print detection bounding boxes
Notes:
- If 'source' is not provided, it defaults to the ASSETS constant with a warning.
- The method sets up a new predictor if not already present and updates its arguments with each call.
- For SAM-type models, 'prompts' can be passed as a keyword argument.
"""
# 检查 source 是否为 None 。
if source is None:
# 如果是,则将 source 设置为默认值 ASSETS ,并记录警告信息。 ASSETS 是一个预定义的默认输入源,通常用于测试或演示。
# ASSETS -> 定义默认图像资产的路径。
source = ASSETS
LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.") # 警告 ⚠️ 缺少“source”。使用“source={source}”。
# 检查当前运行环境是否为命令行界面(CLI)模式。
# 检查脚本名称是否以 "yolo" 或 "ultralytics" 结尾。
# 检查命令行参数中是否包含 "predict" 、 "track" 或相关的模式标志。
# 如果满足条件,则将 is_cli 设置为 True ,表示当前运行在 CLI 模式下。
is_cli = (ARGV[0].endswith("yolo") or ARGV[0].endswith("ultralytics")) and any(
x in ARGV for x in ("predict", "track", "mode=predict", "mode=track")
)
# 定义 默认的预测参数 custom ,包括 :
# conf=0.25 :置信度阈值。
# batch=1 :批量大小。
# save=is_cli :是否保存结果,取决于是否运行在 CLI 模式。
# mode="predict" :预测模式。
custom = {"conf": 0.25, "batch": 1, "save": is_cli, "mode": "predict"} # method defaults
# 将 self.overrides 、 custom 和 kwargs 合并为 最终的参数字典 args ,其中 kwargs 的优先级最高。
args = {**self.overrides, **custom, **kwargs} # highest priority args on the right
# 如果 args 中包含 "prompts" 键,则提取并存储到 prompts 中(用于 SAM 类型模型)。
prompts = args.pop("prompts", None) # for SAM-type models
# 如果 self.predictor 未初始化,则。
if not self.predictor:
# 如果提供了自定义 predictor ,使用它;否则调用 _smart_load("predictor") 动态加载默认预测器 。 使用 args 和 self.callbacks 初始化预测器。
self.predictor = (predictor or self._smart_load("predictor"))(overrides=args, _callbacks=self.callbacks)
# 调用 self.predictor.setup_model 设置模型 ,传递 self.model 和 is_cli 作为参数。
self.predictor.setup_model(model=self.model, verbose=is_cli)
# 如果 self.predictor 已经初始化,则。
else: # only update args if predictor is already setup
# 使用 get_cfg 更新预测器的参数 。
# def get_cfg(cfg: Union[str, Path, Dict, SimpleNamespace] = DEFAULT_CFG_DICT, overrides: Dict = None):
# -> 用于处理和验证配置信息,最终返回一个配置对象。将最终的配置字典 cfg 转换为 IterableSimpleNamespace 对象并返回。 IterableSimpleNamespace 是一个可迭代的命名空间对象,支持通过点符号访问属性(如 cfg.name ),同时也支持字典操作(如 cfg["name"] )。
# -> return IterableSimpleNamespace(**cfg)
self.predictor.args = get_cfg(self.predictor.args, args)
# 如果 args 中包含 "project" 或 "name" , 更新预测器的保存目录 。
if "project" in args or "name" in args:
# def get_save_dir(args, name=None): -> 用于根据输入参数 args 和可选参数 name 生成一个保存目录路径( save_dir )。它的主要功能是根据用户提供的参数动态生成保存目录,并确保目录路径的唯一性。返回最终的保存目录路径,确保其类型为 Path 对象。 -> return Path(save_dir)
self.predictor.save_dir = get_save_dir(self.predictor.args)
# 如果存在 prompts 且预测器支持 set_prompts 方法,则调用该方法设置提示(用于 SAM 类型模型)。
if prompts and hasattr(self.predictor, "set_prompts"): # for SAM-type models
self.predictor.set_prompts(prompts)
# 如果运行在 CLI 模式下,则调用 self.predictor.predict_cli 方法。 否则,调用 self.predictor 的预测方法,传递 source 和 stream 参数。 返回 预测结果 。
return self.predictor.predict_cli(source=source) if is_cli else self.predictor(source=source, stream=stream)
# 这个方法的核心功能是执行模型的预测操作,支持多种输入源和灵活的配置。它通过以下步骤实现。检查输入源:如果未提供输入源,则使用默认值并记录警告。检测 CLI 模式:根据命令行参数判断是否运行在 CLI 模式下。合并预测参数:将默认参数、覆盖参数和用户提供的参数合并。初始化或更新预测器:如果预测器未初始化,则加载并设置;如果已初始化,则更新参数。设置提示(可选):如果支持,为预测器设置提示。执行预测:根据运行模式调用相应的预测方法,并返回结果。使用场景。通用预测:支持从多种输入源(如图像、视频流、摄像头等)执行预测。CLI 模式支持:在命令行界面下提供特定的行为(如保存结果)。灵活配置:允许用户通过关键字参数自定义预测过程。注意事项。如果需要自定义预测器,可以通过 predictor 参数提供。如果输入源未指定,将使用默认值,但可能会记录警告。
# 这段代码定义了 Model 类的 track 方法,用于执行目标跟踪任务。它支持多种输入源类型,并允许通过关键字参数进行灵活配置。
# 定义了一个名为 track 的方法,属于 Model 类。 参数 :
# 1.source :输入源,可以是文件路径、URL、摄像头编号、图像列表、NumPy 数组或 PyTorch 张量。
# 2.stream :布尔值,表示是否将输入源视为连续流(如视频流)。默认为 False 。
# 3.persist :布尔值,表示是否持久化跟踪器的状态。默认为 False 。
# 4.**kwargs :其他关键字参数,用于配置跟踪过程。
# 返回一个 List[Results] ,包含跟踪结果。
def track(
self,
source: Union[str, Path, int, list, tuple, np.ndarray, torch.Tensor] = None,
stream: bool = False,
persist: bool = False,
**kwargs: Any,
) -> List[Results]:
# 使用已注册的跟踪器对指定的输入源进行对象跟踪。
# 此方法使用模型的预测器和可选的已注册跟踪器执行对象跟踪。它处理各种输入源,例如文件路径或视频流,并支持通过关键字参数进行自定义。该方法注册跟踪器(如果尚未存在),并可以在调用之间保留它们。
# 引发:
# AttributeError:如果预测器没有注册的跟踪器。
# 注释:
# - 此方法为基于 ByteTrack 的跟踪设置默认置信度阈值 0.1。
# - 跟踪模式在关键字参数中明确设置。
# - 对于视频中的跟踪,批量大小设置为 1。
"""
Conducts object tracking on the specified input source using the registered trackers.
This method performs object tracking using the model's predictors and optionally registered trackers. It handles
various input sources such as file paths or video streams, and supports customization through keyword arguments.
The method registers trackers if not already present and can persist them between calls.
Args:
source (Union[str, Path, int, List, Tuple, np.ndarray, torch.Tensor], optional): Input source for object
tracking. Can be a file path, URL, or video stream.
stream (bool): If True, treats the input source as a continuous video stream. Defaults to False.
persist (bool): If True, persists trackers between different calls to this method. Defaults to False.
**kwargs: Additional keyword arguments for configuring the tracking process.
Returns:
(List[ultralytics.engine.results.Results]): A list of tracking results, each a Results object.
Raises:
AttributeError: If the predictor does not have registered trackers.
Examples:
>>> model = YOLO("yolo11n.pt")
>>> results = model.track(source="path/to/video.mp4", show=True)
>>> for r in results:
... print(r.boxes.id) # print tracking IDs
Notes:
- This method sets a default confidence threshold of 0.1 for ByteTrack-based tracking.
- The tracking mode is explicitly set in the keyword arguments.
- Batch size is set to 1 for tracking in videos.
"""
# 检查 self.predictor 是否有 trackers 属性。如果没有,则。
if not hasattr(self.predictor, "trackers"):
# 导入 register_tracker 函数。
from ultralytics.trackers import register_tracker
# 调用 register_tracker 函数,为当前模型注册一个跟踪器。 persist 参数决定是否持久化跟踪器的状态。
register_tracker(self, persist)
# 设置跟踪任务的 置信度阈值 。 如果 kwargs 中已经指定了 "conf" ,则使用指定的值。 否则,默认设置为 0.1 。这是因为基于 ByteTrack 的跟踪方法需要较低置信度的预测结果作为输入。
kwargs["conf"] = kwargs.get("conf") or 0.1 # ByteTrack-based method needs low confidence predictions as input
# 设置跟踪任务的 批量大小 。 如果 kwargs 中已经指定了 "batch" ,则使用指定的值。 否则,默认设置为 1 。这是因为视频跟踪通常以单帧处理,批量大小为 1 。
kwargs["batch"] = kwargs.get("batch") or 1 # batch-size 1 for tracking in videos
# 设置 kwargs 中的 "mode" 为 "track" ,明确指定当前任务为跟踪模式。
kwargs["mode"] = "track"
# 调用 self.predict 方法,将输入源和配置参数传递给它。 self.predict 方法会根据输入的 source 类型和配置执行相应的跟踪逻辑。 返回 跟踪结果 。
return self.predict(source=source, stream=stream, **kwargs)
# 这个方法的核心功能是执行目标跟踪任务,支持多种输入源和灵活的配置。它通过以下步骤实现。检查跟踪器是否注册:如果未注册,则调用 register_tracker 注册跟踪器。设置跟踪参数:置信度阈值默认为 0.1 ,适合低置信度预测。批量大小默认为 1 ,适合视频流处理。调用预测方法:将输入源和配置参数传递给 self.predict 方法,执行跟踪任务。返回跟踪结果:返回一个包含跟踪结果的列表。使用场景。视频目标跟踪:适用于从视频流中跟踪目标。图像序列跟踪:适用于从图像序列中跟踪目标。灵活配置:用户可以通过关键字参数自定义跟踪过程。注意事项。如果需要持久化跟踪器状态,可以通过 persist 参数设置。跟踪任务的置信度阈值较低,以确保更多的预测结果被用于跟踪。
# 这段代码定义了 Model 类的 val 方法,用于执行模型的验证操作。它支持自定义验证器和灵活的配置选项。
# 定义了一个名为 val 的方法,属于 Model 类。 参数 :
# 1.validator :自定义验证器对象。如果未提供,则使用默认的验证器。
# 2.**kwargs :其他关键字参数,用于配置验证过程。
# 返回验证过程的指标( validator.metrics )。
def val(
self,
validator=None,
**kwargs: Any,
):
# 使用指定的数据集和验证配置验证模型。
# 此方法简化了模型验证过程,允许通过各种设置进行自定义。它支持使用自定义验证器或默认验证方法进行验证。该方法结合了默认配置、方法特定的默认值和用户提供的参数来配置验证过程。
# 引发:
# AssertionError:如果模型不是 PyTorch 模型。
"""
Validates the model using a specified dataset and validation configuration.
This method facilitates the model validation process, allowing for customization through various settings. It
supports validation with a custom validator or the default validation approach. The method combines default
configurations, method-specific defaults, and user-provided arguments to configure the validation process.
Args:
validator (ultralytics.engine.validator.BaseValidator | None): An instance of a custom validator class for
validating the model.
**kwargs: Arbitrary keyword arguments for customizing the validation process.
Returns:
(ultralytics.utils.metrics.DetMetrics): Validation metrics obtained from the validation process.
Raises:
AssertionError: If the model is not a PyTorch model.
Examples:
>>> model = YOLO("yolo11n.pt")
>>> results = model.val(data="coco8.yaml", imgsz=640)
>>> print(results.box.map) # Print mAP50-95
"""
# 定义 默认的验证参数 custom ,其中 "rect": True 表示在验证过程中使用矩形裁剪(可能与数据预处理相关)。
custom = {"rect": True} # method defaults
# 合并参数。
# self.overrides :模型的覆盖参数。
# custom :默认的验证参数。
# kwargs :用户提供的关键字参数。
# "mode": "val" :明确指定当前模式为验证模式。
# 合并后的参数中, kwargs 的优先级最高,其次是 custom 和 self.overrides 。
args = {**self.overrides, **custom, **kwargs, "mode": "val"} # highest priority args on the right
# 如果用户提供了自定义的 validator ,则直接使用它;否则调用 _smart_load("validator") 动态加载默认的验证器。 使用 args 和 self.callbacks 初始化验证器。
validator = (validator or self._smart_load("validator"))(args=args, _callbacks=self.callbacks)
# 调用验证器的 __call__ 方法(或直接作为函数调用),将当前模型传递给验证器。 验证器会根据提供的模型和参数执行验证过程。
validator(model=self.model)
# 将验证器的指标( validator.metrics )存储到 self.metrics 中,以便后续访问。
self.metrics = validator.metrics
# 返回 验证器的指标 ,供用户使用。
return validator.metrics
# 这个方法的核心功能是执行模型的验证操作,支持自定义验证器和灵活的配置。它通过以下步骤实现。设置默认参数:定义默认的验证参数(如 "rect": True )。合并参数:将模型的覆盖参数、默认参数和用户提供的参数合并。初始化验证器:如果用户提供了自定义验证器,则使用它;否则加载默认验证器。执行验证:调用验证器的 __call__ 方法,将当前模型传递给验证器。存储和返回指标:将验证器的指标存储到 self.metrics 中,并返回这些指标。使用场景。模型验证:在训练完成后,使用此方法评估模型的性能。灵活配置:用户可以通过 kwargs 提供自定义的验证参数。自定义验证器:用户可以提供自己的验证器,以满足特定需求。注意事项。如果需要自定义验证器,可以通过 validator 参数提供。验证器的具体实现细节取决于底层验证器类,确保其支持所需的验证逻辑。
# 这段代码定义了 Model 类的 benchmark 方法,用于对模型进行性能基准测试。它支持多种配置选项,并通过调用 ultralytics.utils.benchmarks.benchmark 函数来执行实际的测试。
# 定义了一个名为 benchmark 的方法,属于 Model 类。 参数 :
# 1.**kwargs :其他关键字参数,用于配置基准测试过程。
def benchmark(
self,
**kwargs: Any,
):
# 在各种导出格式中对模型进行基准测试以评估性能。
# 此方法评估模型在不同导出格式(例如 ONNX、TorchScript 等)中的性能。
# 它使用 ultralytics.utils.benchmarks 模块中的“benchmark”函数。基准测试使用默认配置值、模型特定参数、方法特定默认值以及任何其他用户提供的关键字参数的组合进行配置。
# 引发:
# AssertionError:如果模型不是 PyTorch 模型。
"""
Benchmarks the model across various export formats to evaluate performance.
This method assesses the model's performance in different export formats, such as ONNX, TorchScript, etc.
It uses the 'benchmark' function from the ultralytics.utils.benchmarks module. The benchmarking is
configured using a combination of default configuration values, model-specific arguments, method-specific
defaults, and any additional user-provided keyword arguments.
Args:
**kwargs: Arbitrary keyword arguments to customize the benchmarking process. These are combined with
default configurations, model-specific arguments, and method defaults. Common options include:
- data (str): Path to the dataset for benchmarking.
- imgsz (int | List[int]): Image size for benchmarking.
- half (bool): Whether to use half-precision (FP16) mode.
- int8 (bool): Whether to use int8 precision mode.
- device (str): Device to run the benchmark on (e.g., 'cpu', 'cuda').
- verbose (bool): Whether to print detailed benchmark information.
Returns:
(Dict): A dictionary containing the results of the benchmarking process, including metrics for
different export formats.
Raises:
AssertionError: If the model is not a PyTorch model.
Examples:
>>> model = YOLO("yolo11n.pt")
>>> results = model.benchmark(data="coco8.yaml", imgsz=640, half=True)
>>> print(results)
"""
# 调用 _check_is_pytorch_model 方法,确保当前模型是一个 PyTorch 模型。 如果模型不是 PyTorch 模型,会抛出 TypeError ,阻止非 PyTorch 模型执行基准测试。
self._check_is_pytorch_model()
# 导入 benchmark 函数,该函数用于执行模型的性能基准测试。 这个函数通常会评估模型在不同条件下的性能,如推理速度、内存占用等。
from ultralytics.utils.benchmarks import benchmark
# 定义默认的 基准测试参数 custom ,其中 "verbose": False 表示默认情况下不打印详细的测试信息。
custom = {"verbose": False} # method defaults
# 合并参数。
# DEFAULT_CFG_DICT :全局默认配置。
# self.model.args :模型自身的参数。
# custom :方法默认参数。
# kwargs :用户提供的关键字参数。
# "mode": "benchmark" :明确指定当前模式为基准测试模式。
# 合并后的参数中, kwargs 的优先级最高,其次是 custom 、 self.model.args 和 DEFAULT_CFG_DICT 。
args = {**DEFAULT_CFG_DICT, **self.model.args, **custom, **kwargs, "mode": "benchmark"}
# 调用 benchmark 函数,执行模型的性能基准测试。 返回 基准测试的结果 。
return benchmark(
# 当前模型实例。
model=self,
# 数据集路径。如果未提供,则使用默认数据集。
data=kwargs.get("data"), # if no 'data' argument passed set data=None for default datasets
# 输入图像尺寸。
imgsz=args["imgsz"],
# 是否使用半精度(FP16)推理。
half=args["half"],
# 是否使用 INT8 量化推理。
int8=args["int8"],
# 运行设备(如 "cpu" 或 "cuda" )。
device=args["device"],
# 是否打印详细的测试信息。
verbose=kwargs.get("verbose"),
)
# 这个方法的核心功能是对模型进行性能基准测试,支持多种配置选项。它通过以下步骤实现。检查模型是否为 PyTorch 模型:确保只有 PyTorch 模型可以执行基准测试。设置默认参数:定义默认的基准测试参数。合并参数:将全局默认配置、模型参数、方法默认参数和用户提供的参数合并。调用基准测试函数:使用 benchmark 函数执行实际的性能测试。返回测试结果:返回基准测试的结果。使用场景。性能评估:在模型开发或优化后,使用此方法评估模型的性能。灵活配置:用户可以通过 kwargs 提供自定义的测试参数。多种推理模式:支持 FP16 和 INT8 量化推理,适用于不同的硬件环境。注意事项。如果需要自定义数据集,可以通过 kwargs 提供 data 参数。如果需要详细输出,可以通过 kwargs 设置 verbose=True 。
# 这段代码定义了 Model 类的 export 方法,用于将模型导出为其他格式(如 ONNX、TorchScript 等),以便用于部署或与其他工具集成。
# 定义了一个名为 export 的方法,属于 Model 类。 参数 :
# 1.**kwargs :其他关键字参数,用于配置导出过程。
# 返回一个字符串,表示导出模型的文件路径。
def export(
self,
**kwargs: Any,
) -> str:
# 将模型导出为适合部署的其他格式。
# 此方法有助于将模型导出为各种格式(例如 ONNX、TorchScript)以进行部署。它使用“Exporter”类进行导出过程,结合特定于模型的覆盖、方法默认值和提供的任何其他参数。
# 引发:
# AssertionError:如果模型不是 PyTorch 模型。
# ValueError:如果指定了不支持的导出格式。
# RuntimeError:如果导出过程由于错误而失败。
"""
Exports the model to a different format suitable for deployment.
This method facilitates the export of the model to various formats (e.g., ONNX, TorchScript) for deployment
purposes. It uses the 'Exporter' class for the export process, combining model-specific overrides, method
defaults, and any additional arguments provided.
Args:
**kwargs: Arbitrary keyword arguments to customize the export process. These are combined with
the model's overrides and method defaults. Common arguments include:
format (str): Export format (e.g., 'onnx', 'engine', 'coreml').
half (bool): Export model in half-precision.
int8 (bool): Export model in int8 precision.
device (str): Device to run the export on.
workspace (int): Maximum memory workspace size for TensorRT engines.
nms (bool): Add Non-Maximum Suppression (NMS) module to model.
simplify (bool): Simplify ONNX model.
Returns:
(str): The path to the exported model file.
Raises:
AssertionError: If the model is not a PyTorch model.
ValueError: If an unsupported export format is specified.
RuntimeError: If the export process fails due to errors.
Examples:
>>> model = YOLO("yolo11n.pt")
>>> model.export(format="onnx", dynamic=True, simplify=True)
'path/to/exported/model.onnx'
"""
# 调用 _check_is_pytorch_model 方法,确保当前模型是一个 PyTorch 模型。 如果模型不是 PyTorch 模型,会抛出 TypeError ,阻止非 PyTorch 模型执行导出操作。
self._check_is_pytorch_model()
# 从当前模块的子模块 exporter 中导入 Exporter 类。 Exporter 类负责执行模型的导出逻辑。
from .exporter import Exporter
# 定义默认的导出参数 custom 。
custom = {
# 从模型参数中获取 输入图像尺寸 。
"imgsz": self.model.args["imgsz"],
# 设置 批量大小 为 1 ,适用于大多数导出场景。
"batch": 1,
# 不指定数据集路径(默认为 None )。
"data": None,
# 不指定设备(默认为 None ),以避免多 GPU 错误。
"device": None, # reset to avoid multi-GPU errors
# 默认不打印详细信息。
"verbose": False,
} # method defaults
# 合并参数。
# self.overrides :模型的覆盖参数。
# custom :默认的导出参数。
# kwargs :用户提供的关键字参数。
# "mode": "export" :明确指定当前模式为导出模式。
# 合并后的参数中, kwargs 的优先级最高,其次是 custom 和 self.overrides 。
args = {**self.overrides, **custom, **kwargs, "mode": "export"} # highest priority args on the right
# 创建 Exporter 实例,传入合并后的参数 args 和回调函数 self.callbacks 。 调用 Exporter 实例的 __call__ 方法(或直接作为函数调用),将当前模型传递给导出器。 导出器会根据提供的模型和参数执行导出逻辑,并返回 导出模型的文件路径 。
return Exporter(overrides=args, _callbacks=self.callbacks)(model=self.model)
# 这个方法的核心功能是将模型导出为其他格式,以便用于部署或与其他工具集成。它通过以下步骤实现。检查模型是否为 PyTorch 模型:确保只有 PyTorch 模型可以执行导出操作。设置默认参数:定义默认的导出参数。合并参数:将模型的覆盖参数、默认参数和用户提供的参数合并。调用导出器:使用 Exporter 类执行实际的导出逻辑。返回导出文件路径:返回导出模型的文件路径。使用场景。模型部署:将模型导出为 ONNX、TorchScript 等格式,以便在其他平台上运行。与其他工具集成:导出的模型可以用于推理优化、量化或其他工具的输入。灵活配置:用户可以通过 kwargs 提供自定义的导出参数。注意事项。如果需要指定导出格式(如 ONNX 或 TorchScript),可以通过 kwargs 提供相关参数。导出过程中可能会涉及设备选择(如 CPU 或 GPU),但默认情况下不指定设备以避免多 GPU 错误。
# 这段代码定义了 Model 类的 train 方法,用于执行模型的训练过程。它支持从头开始训练、从检查点恢复训练,以及与 Ultralytics HUB 集成。
# 定义了一个名为 train 的方法,属于 Model 类。 参数 :
# 1.trainer :自定义训练器对象。如果未提供,则使用默认的训练器。
# 2.**kwargs :其他关键字参数,用于配置训练过程。
def train(
self,
trainer=None,
**kwargs: Any,
):
# 使用指定的数据集和训练配置训练模型。
# 此方法通过一系列可自定义的设置促进模型训练。它支持使用自定义训练器或默认训练方法进行训练。该方法处理从检查点恢复训练、与 Ultralytics HUB 集成以及训练后更新模型和配置等场景。
# 使用 Ultralytics HUB 时,如果会话已加载模型,该方法将优先考虑 HUB 训练参数,并在提供本地参数时发出警告。它会检查 pip 更新并结合默认配置、方法特定的默认值和用户提供的参数来配置训练过程。
# 引发:
# AssertionError:如果模型不是 PyTorch 模型。
# PermissionError:如果 HUB 会话存在权限问题。
# ModuleNotFoundError:如果未安装 HUB SDK。
"""
Trains the model using the specified dataset and training configuration.
This method facilitates model training with a range of customizable settings. It supports training with a
custom trainer or the default training approach. The method handles scenarios such as resuming training
from a checkpoint, integrating with Ultralytics HUB, and updating model and configuration after training.
When using Ultralytics HUB, if the session has a loaded model, the method prioritizes HUB training
arguments and warns if local arguments are provided. It checks for pip updates and combines default
configurations, method-specific defaults, and user-provided arguments to configure the training process.
Args:
trainer (BaseTrainer | None): Custom trainer instance for model training. If None, uses default.
**kwargs: Arbitrary keyword arguments for training configuration. Common options include:
data (str): Path to dataset configuration file.
epochs (int): Number of training epochs.
batch_size (int): Batch size for training.
imgsz (int): Input image size.
device (str): Device to run training on (e.g., 'cuda', 'cpu').
workers (int): Number of worker threads for data loading.
optimizer (str): Optimizer to use for training.
lr0 (float): Initial learning rate.
patience (int): Epochs to wait for no observable improvement for early stopping of training.
Returns:
(Dict | None): Training metrics if available and training is successful; otherwise, None.
Raises:
AssertionError: If the model is not a PyTorch model.
PermissionError: If there is a permission issue with the HUB session.
ModuleNotFoundError: If the HUB SDK is not installed.
Examples:
>>> model = YOLO("yolo11n.pt")
>>> results = model.train(data="coco8.yaml", epochs=3)
"""
# 调用 _check_is_pytorch_model 方法,确保当前模型是一个 PyTorch 模型。 如果模型不是 PyTorch 模型,会抛出 TypeError ,阻止非 PyTorch 模型执行训练操作。
self._check_is_pytorch_model()
# 检查是否存在一个有效的 Ultralytics HUB 会话( self.session )且会话中加载了模型。
if hasattr(self.session, "model") and self.session.model.id: # Ultralytics HUB session with loaded model
# 如果存在,且用户提供了本地训练参数( kwargs ),则发出警告,提示将使用 HUB 的训练参数。
if any(kwargs):
LOGGER.warning("WARNING ⚠️ using HUB training arguments, ignoring local training arguments.") # 警告⚠️使用 HUB 训练参数,忽略本地训练参数。
# 使用 HUB 的训练参数覆盖本地参数。
kwargs = self.session.train_args # overwrite kwargs
# 调用 check_pip_update_available 方法,检查是否有可用的 pip 更新。 这一步确保用户使用的是最新版本的依赖库。
# def check_pip_update_available(): -> 用于检查 PyPI 上是否有 ultralytics 包的新版本可用。返回 True ,表示有新版本可用。如果没有新版本可用,或者在检查过程中发生异常,则返回 False 。 -> return True / return False
checks.check_pip_update_available()
# 这段代码是 train 方法中的一部分,用于处理训练参数的加载、合并和优先级排序。
# 加载配置文件。
# 如果 kwargs 中提供了 "cfg" 参数(即配置文件路径),则调用 checks.check_yaml 方法,检查配置文件路径是否有效,并返回标准化的路径。 使用 yaml_load 方法加载 YAML 文件内容并解析为字典。
# 如果未提供 "cfg" 参数,则直接使用 self.overrides (模型的覆盖参数)。
overrides = yaml_load(checks.check_yaml(kwargs["cfg"])) if kwargs.get("cfg") else self.overrides
# 定义默认参数。
custom = {
# NOTE: handle the case when 'cfg' includes 'data'. 注意:处理‘cfg’包含‘data’的情况。
# 优先从 overrides 中获取 "data" (数据集路径)。 如果未指定,则使用 DEFAULT_CFG_DICT["data"] (全局默认配置)。 如果仍未指定,则使用 TASK2DATA[self.task] (根据任务类型自动选择的数据集)。
"data": overrides.get("data") or DEFAULT_CFG_DICT["data"] or TASK2DATA[self.task],
# 使用 self.overrides["model"] (模型路径)。
"model": self.overrides["model"],
# 使用 self.task (任务类型)。
"task": self.task,
} # method defaults
# 合并参数。
# 合并 overrides (配置文件参数)、 custom (默认参数)和 kwargs (用户提供的参数)。
# 合并顺序决定了参数的优先级 : overrides 的优先级最低。 custom 的优先级次之。 kwargs 的优先级最高。
# 最后,添加 "mode": "train" ,明确指定当前模式为训练模式。
args = {**overrides, **custom, **kwargs, "mode": "train"} # highest priority args on the right
# 处理恢复训练。 如果 args 中包含 "resume" 参数(表示从检查点恢复训练),则将其值设置为 self.ckpt_path (当前检查点路径)。 这一步确保恢复训练时使用正确的检查点路径。
if args.get("resume"):
args["resume"] = self.ckpt_path
# 这段代码的核心功能是处理训练参数的加载、合并和优先级排序。具体步骤如下。加载配置文件:如果提供了配置文件路径( kwargs["cfg"] ),加载并解析为字典。如果未提供,使用模型的覆盖参数( self.overrides )。定义默认参数:提供一组默认参数( custom ),包括数据集路径、模型路径和任务类型。合并参数:将配置文件参数、默认参数和用户提供的参数合并,确保用户提供的参数具有最高优先级。处理恢复训练:如果需要从检查点恢复训练,设置正确的检查点路径。
# 这段代码是 train 方法中的一部分,用于初始化训练器( trainer )并根据需要设置模型。
# 初始化训练器。
# 如果用户提供了自定义的 trainer 对象,则直接使用它。
# 如果未提供,则调用 _smart_load("trainer") 方法动态加载默认的训练器类。 使用 overrides=args 和 _callbacks=self.callbacks 初始化训练器实例。 args 是合并后的训练参数,包含用户提供的参数、默认参数和配置文件参数。 _callbacks=self.callbacks 是训练过程中使用的回调函数。
self.trainer = (trainer or self._smart_load("trainer"))(overrides=args, _callbacks=self.callbacks)
# 检查是否恢复训练。 如果 args 中没有 "resume" 参数(即不是从检查点恢复训练),则手动设置训练器的模型。 如果存在 "resume" 参数,则跳过这一步,因为恢复训练时模型会从检查点加载。
if not args.get("resume"): # manually set model only if not resuming
# 获取模型实例。 调用训练器的 get_model 方法来获取模型实例。
# weights :如果存在检查点( self.ckpt ),则使用当前模型的权重;否则为 None 。
# cfg :模型的配置文件路径( self.model.yaml )。
# 这一步确保训练器使用正确的模型实例和权重。
self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
# 更新当前模型实例。 将 训练器的模型实例 ( self.trainer.model )赋值给 self.model 。 这一步确保当前模型实例与训练器使用的模型实例一致。
self.model = self.trainer.model
# 这段代码的核心功能是初始化训练器并根据需要设置模型。具体步骤如下。初始化训练器:如果用户提供了自定义训练器,则直接使用它;否则加载默认训练器。使用合并后的参数和回调函数初始化训练器实例。检查是否恢复训练:如果不是从检查点恢复训练,则手动设置训练器的模型。获取模型实例:调用训练器的 get_model 方法,传入权重和配置文件路径,获取模型实例。更新当前模型实例:确保当前模型实例与训练器使用的模型实例一致。
# 这段代码是 train 方法的最后部分,负责执行训练过程、更新模型状态以及返回训练指标。
# 附加 HUB 会话。 如果存在一个有效的 Ultralytics HUB 会话( self.session ),将其附加到训练器( self.trainer )上。 这使得训练器可以利用 HUB 提供的额外功能,例如自动上传训练结果、管理模型版本等。 如果没有 HUB 会话,这一步不会报错,只是不会附加任何会话。
self.trainer.hub_session = self.session # attach optional HUB session
# 执行训练过程。 调用训练器的 train 方法,开始训练模型。 这一步是训练过程的核心,具体实现细节由训练器类决定,通常包括 :数据加载和预处理。模型训练循环。保存训练过程中的检查点。验证模型性能(如果配置了验证集)。
self.trainer.train()
# Update model and cfg after training
# 更新模型和配置。 这一步仅在主进程( RANK in {-1, 0} )中执行,以避免在分布式训练中重复操作。 RANK 是一个标识符,用于区分主进程和子进程。 -1 表示单进程训练, 0 表示分布式训练中的主进程。
if RANK in {-1, 0}:
# 选择最佳或最后一个检查点。 如果训练器有最佳检查点( self.trainer.best ),且该文件存在,则选择最佳检查点。 否则,选择最后一个检查点( self.trainer.last )。 这一步确保加载的检查点是训练过程中性能最好的模型,或者至少是最新的模型。
ckpt = self.trainer.best if self.trainer.best.exists() else self.trainer.last
# 加载检查点。 使用 attempt_load_one_weight 函数加载选定的检查点文件。 这一步会返回两个值。 self.model :加载后的 模型实例 。 self.ckpt :检查点文件的 路径 。 这确保了训练后的模型状态被正确加载到当前模型实例中。
# def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False): -> 用于加载单个模型权重文件,并对模型进行一系列的初始化和兼容性处理。返回 处理后的模型 和 检查点数据 。 -> return model, ckpt
self.model, self.ckpt = attempt_load_one_weight(ckpt)
# 更新模型参数。 将 加载的模型的参数 ( self.model.args )更新到 self.overrides 中。 这一步确保模型的参数(如训练配置、超参数等)与训练后的状态一致。
self.overrides = self.model.args
# 获取训练指标。 从训练器的验证器( self.trainer.validator )中获取训练指标( metrics )。 如果验证器没有返回指标,则设置为 None 。 注意 :这里提到的 TODO 表示在分布式训练(DDP)中,验证器可能不会返回指标,这是一个待解决的问题。
self.metrics = getattr(self.trainer.validator, "metrics", None) # TODO: no metrics returned by DDP
# 返回 训练过程的指标 ,供用户使用。 这些指标通常包括验证集上的性能指标(如准确率、损失值等),用于评估模型的训练效果。
return self.metrics
# 这段代码的核心功能是执行训练过程并更新模型状态。具体步骤如下。附加 HUB 会话:将可选的 HUB 会话附加到训练器上。执行训练:调用训练器的 train 方法,开始训练过程。更新模型状态:在主进程中,选择最佳或最后一个检查点。加载检查点,更新模型实例和参数。获取训练指标。返回指标:返回训练过程的指标,供用户使用。
# 这个方法的核心功能是执行模型的训练过程,支持从头开始训练、从检查点恢复训练,以及与 Ultralytics HUB 集成。它通过以下步骤实现。检查模型是否为 PyTorch 模型:确保只有 PyTorch 模型可以执行训练操作。处理 HUB 会话:如果存在 HUB 会话,使用 HUB 的训练参数。检查 pip 更新:确保使用最新版本的依赖库。加载配置文件:如果提供了配置文件路径,加载并解析配置文件。设置默认参数:定义默认的训练参数。合并参数:将配置文件参数、默认参数和用户提供的参数合并。初始化训练器:如果用户提供了自定义训练器,则使用它;否则加载默认训练器。设置模型:如果不是从检查点恢复训练,则手动设置训练器的模型。附加 HUB 会话:将 HUB 会话附加到训练器上。执行训练:调用训练器的 train 方法,开始训练过程。更新模型和配置:在训练完成后,加载最佳检查点,更新模型和配置。返回训练指标:返回训练过程的指标。使用场景。模型训练:从头开始训练模型或从检查点恢复训练。与 HUB 集成:使用 Ultralytics HUB 的训练参数和会话。灵活配置:用户可以通过 kwargs 提供自定义的训练参数。注意事项。如果需要从检查点恢复训练,可以通过 kwargs 提供 "resume" 参数。如果使用 HUB 会话,本地训练参数将被忽略。
# 这段代码定义了 Model 类的 tune 方法,用于执行模型的超参数调优。它支持使用 Ray Tune 或自定义的 Tuner 类来进行调优。
# 定义了一个名为 tune 的方法,属于 Model 类。 参数 :
# 1.use_ray :布尔值,表示是否使用 Ray Tune 进行调优。默认为 False 。
# 2.iterations :整数,表示调优的迭代次数。默认为 10 。
# 3.*args :可变位置参数,用于传递额外的参数。
# 4.**kwargs :关键字参数,用于传递额外的配置。
def tune(
self,
use_ray=False,
iterations=10,
*args: Any,
**kwargs: Any,
):
# 对模型进行超参数调整,可选择使用 Ray Tune。
# 此方法支持两种超参数调整模式:使用 Ray Tune 或自定义调整方法。启用 Ray Tune 后,它会利用 ultralytics.utils.tuner 模块中的“run_ray_tune”函数。否则,它会使用内部“Tuner”类进行调整。该方法结合了默认、覆盖和自定义参数来配置调整过程。
# 引发:
# AssertionError:如果模型不是 PyTorch 模型。
"""
Conducts hyperparameter tuning for the model, with an option to use Ray Tune.
This method supports two modes of hyperparameter tuning: using Ray Tune or a custom tuning method.
When Ray Tune is enabled, it leverages the 'run_ray_tune' function from the ultralytics.utils.tuner module.
Otherwise, it uses the internal 'Tuner' class for tuning. The method combines default, overridden, and
custom arguments to configure the tuning process.
Args:
use_ray (bool): If True, uses Ray Tune for hyperparameter tuning. Defaults to False.
iterations (int): The number of tuning iterations to perform. Defaults to 10.
*args: Variable length argument list for additional arguments.
**kwargs: Arbitrary keyword arguments. These are combined with the model's overrides and defaults.
Returns:
(Dict): A dictionary containing the results of the hyperparameter search.
Raises:
AssertionError: If the model is not a PyTorch model.
Examples:
>>> model = YOLO("yolo11n.pt")
>>> results = model.tune(use_ray=True, iterations=20)
>>> print(results)
"""
# 调用 _check_is_pytorch_model 方法,确保当前模型是一个 PyTorch 模型。 如果模型不是 PyTorch 模型,会抛出 TypeError ,阻止非 PyTorch 模型执行调优操作。
self._check_is_pytorch_model()
# 使用 Ray Tune 进行调优。
# 如果 use_ray 为 True ,则导入 run_ray_tune 函数。
if use_ray:
from ultralytics.utils.tuner import run_ray_tune
# 调用 run_ray_tune 函数,传入当前模型实例、最大样本数( iterations )以及其他参数。 返回 Ray Tune 的调优结果。
# def run_ray_tune(model, space: dict = None, grace_period: int = 10, gpu_per_trial: int = None, max_samples: int = 10, **train_args,):
# -> 用于使用 Ray Tune 框架进行超参数优化。它支持对模型的训练参数进行自动调优,并提供了与 WandB 集成的日志记录功能。返回 超参数优化的结果 results 。 这个结果对象可以被进一步分析,例如提取最佳试验的超参数配置,或者用于后续的模型训练和评估。
# -> return results
return run_ray_tune(self, max_samples=iterations, *args, **kwargs)
# 使用自定义 Tuner 类进行调优。
# 如果 use_ray 为 False ,则导入 Tuner 类。
else:
from .tuner import Tuner
# 定义默认的调优参数 custom (当前为空)。
custom = {} # method defaults
# 合并参数。
# self.overrides :模型的覆盖参数。
# custom :默认参数。
# kwargs :用户提供的关键字参数。
# "mode": "train" :明确指定当前模式为训练模式。
# 合并后的参数中, kwargs 的优先级最高,其次是 custom 和 self.overrides 。
args = {**self.overrides, **custom, **kwargs, "mode": "train"} # highest priority args on the right
# 初始化 Tuner 实例并执行调优。 创建 Tuner 实例,传入合并后的参数 args 和回调函数 self.callbacks 。 调用 Tuner 实例的 __call__ 方法,传入当前模型实例和调优迭代次数 iterations 。 返回 调优的结果 。
return Tuner(args=args, _callbacks=self.callbacks)(model=self, iterations=iterations)
# 这个方法的核心功能是执行模型的超参数调优,支持两种方式。使用 Ray Tune:如果 use_ray=True ,调用 run_ray_tune 函数进行调优。 Ray Tune 是一个强大的超参数调优库,支持多种调优算法和资源管理。使用自定义 Tuner 类:如果 use_ray=False ,使用 Tuner 类进行调优。 Tuner 类是一个轻量级的调优工具,适用于简单的调优任务。使用场景。超参数调优:在训练模型之前,使用此方法优化超参数,以提高模型性能。灵活配置:用户可以通过 kwargs 提供自定义的调优参数。选择调优工具:根据需求选择使用 Ray Tune 或自定义 Tuner 类。注意事项。如果使用 Ray Tune,需要确保安装了 ray 库。如果使用自定义 Tuner 类,需要确保该类已正确实现调优逻辑。调优过程可能需要较多的计算资源和时间,建议在合适的硬件环境下运行。
# 这段代码定义了 Model 类的 _apply 方法,用于对模型及其子模块应用一个函数 fn 。这通常用于执行一些操作,如将模型移动到不同的设备(CPU 或 GPU)或更改模型的精度(如从 FP32 到 FP16)。
# 定义了一个名为 _apply 的方法,属于 Model 类。 参数 :
# 1.fn :一个函数,将被应用到模型及其子模块上。
# 返回 Model 类的实例(即当前模型实例),允许链式调用。
def _apply(self, fn) -> "Model":
# 将函数应用于非参数或已注册缓冲区的模型张量。
# 此方法通过额外重置预测器并更新模型覆盖中的设备来扩展父类的 _apply 方法的功能。它通常用于将模型移动到其他设备或更改其精度等操作。
# 引发:
# AssertionError:如果模型不是 PyTorch 模型。
"""
Applies a function to model tensors that are not parameters or registered buffers.
This method extends the functionality of the parent class's _apply method by additionally resetting the
predictor and updating the device in the model's overrides. It's typically used for operations like
moving the model to a different device or changing its precision.
Args:
fn (Callable): A function to be applied to the model's tensors. This is typically a method like
to(), cpu(), cuda(), half(), or float().
Returns:
(Model): The model instance with the function applied and updated attributes.
Raises:
AssertionError: If the model is not a PyTorch model.
Examples:
>>> model = Model("yolo11n.pt")
>>> model = model._apply(lambda t: t.cuda()) # Move model to GPU
"""
# 调用 _check_is_pytorch_model 方法,确保当前模型是一个 PyTorch 模型。 如果模型不是 PyTorch 模型,会抛出 TypeError ,阻止非 PyTorch 模型执行此方法。
self._check_is_pytorch_model()
# 调用父类( nn.Module )的 _apply 方法,将函数 fn 应用到模型及其子模块上。 这一步确保模型的所有参数和子模块都执行了 fn 指定的操作。 # noqa 是一个注释,用于跳过某些静态代码检查工具的警告。
self = super()._apply(fn) # noqa
# 将 self.predictor 设置为 None 。 这一步是为了 重置预测器 ,因为设备可能已经改变(例如,模型从 CPU 移动到 GPU)。 重置预测器可以避免因设备不匹配导致的错误。
self.predictor = None # reset predictor as device may have changed
# 更新 self.overrides 字典中的 "device" 键,将其值设置为当前模型所在的设备( self.device )。 这一步确保模型的设备信息是最新的,例如从 "cuda:0" 更新为 "cuda:1" 或 "cpu" 。 这里提到的 str(self.device) 是一种可能的实现方式,将设备对象转换为字符串形式(如 "cuda:0" )。
self.overrides["device"] = self.device # was str(self.device) i.e. device(type='cuda', index=0) -> 'cuda:0'
# 返回当前模型实例( self ),允许链式调用。 这使得用户可以在调用 _apply 方法后继续对模型进行操作,例如 : model = model.to("cuda").half() 。
return self
# 这个方法的核心功能是对模型及其子模块应用一个函数 fn ,并更新模型的状态以适应可能的设备变化。具体步骤如下。检查模型是否为 PyTorch 模型:确保只有 PyTorch 模型可以执行此方法。应用函数:调用父类的 _apply 方法,将 fn 应用到模型及其子模块上。重置预测器:因为设备可能已经改变,所以重置预测器以避免错误。更新设备信息:将模型的设备信息更新到 self.overrides 中。返回模型实例:返回当前模型实例,允许链式调用。使用场景。设备迁移:将模型从 CPU 移动到 GPU 或从一个 GPU 移动到另一个 GPU。精度转换:将模型从 FP32 转换为 FP16 或其他精度。通用操作:对模型及其子模块执行任何自定义操作。注意事项。如果设备发生变化(如从 CPU 到 GPU),确保更新所有相关的状态和配置。如果使用自定义的预测器,可能需要在设备变化后重新初始化预测器。
# 这段代码定义了 Model 类的一个 @property 方法 names ,用于获取模型的类别名称( class names )。这些名称通常是一个字典,将类别索引(整数)映射到类别名称(字符串)。
@property
# 定义了一个名为 names 的 @property 方法,属于 Model 类。
# 返回一个字典,键为类别索引(整数),值为类别名称(字符串)。
def names(self) -> Dict[int, str]:
# 检索与加载的模型关联的类名。
# 如果类名在模型中定义,则此属性返回类名。它使用 ultralytics.nn.autobackend 模块中的“check_class_names”函数检查类名的有效性。如果预测器未初始化,它会在检索名称之前对其进行设置。
# 引发:
# AttributeError:如果模型或预测器没有“names”属性。
"""
Retrieves the class names associated with the loaded model.
This property returns the class names if they are defined in the model. It checks the class names for validity
using the 'check_class_names' function from the ultralytics.nn.autobackend module. If the predictor is not
initialized, it sets it up before retrieving the names.
Returns:
(Dict[int, str]): A dict of class names associated with the model.
Raises:
AttributeError: If the model or predictor does not have a 'names' attribute.
Examples:
>>> model = YOLO("yolo11n.pt")
>>> print(model.names)
{0: 'person', 1: 'bicycle', 2: 'car', ...}
"""
# 导入 check_class_names 函数,该函数用于验证和格式化类别名称。
from ultralytics.nn.autobackend import check_class_names
# 检查模型对象( self.model )是否直接包含 names 属性。
if hasattr(self.model, "names"):
# 如果存在,调用 check_class_names 函数验证和格式化这些类别名称,然后返回结果。 这一步确保返回的类别名称是有效的,并且符合预期的格式。
# # def check_class_names(names): -> 用于检查和处理类别名称( names )。它的主要功能包。将类别名称从列表转换为字典。将类别索引从字符串转换为整数,并确保类别名称为字符串格式。检查类别索引是否有效。如果类别名称是 ImageNet 的类别代码,则将其映射为人类可读的名称。返回处理后的类别名称字典。 -> return names
return check_class_names(self.model.names)
# 如果 self.predictor 未初始化(即 self.predictor 为 None ),则。
if not self.predictor: # export formats will not have predictor defined until predict() is called
# 调用 _smart_load("predictor") 方法 动态加载预测器类 。 使用 overrides=self.overrides 和 _callbacks=self.callbacks 初始化预测器实例 。
self.predictor = self._smart_load("predictor")(overrides=self.overrides, _callbacks=self.callbacks)
# 调用 self.predictor.setup_model 方法,传入当前模型实例( self.model )并设置为静默模式( verbose=False )。 这一步确保在需要时初始化预测器,因为某些导出格式(如 ONNX、TorchScript)可能在调用 predict() 方法之前不会定义预测器。
self.predictor.setup_model(model=self.model, verbose=False)
# 如果模型对象没有直接的 names 属性,则 从预测器的模型中获取类别名称 。 这一步确保即使模型对象本身没有 names 属性,也可以通过预测器获取到类别名称。
return self.predictor.model.names
# 这个 @property 方法的核心功能是获取模型的类别名称。具体步骤如下。检查模型对象是否有 names 属性:如果有,验证并格式化这些类别名称,然后返回。初始化预测器(如果需要):如果模型对象没有 names 属性,且预测器未初始化,则动态加载并初始化预测器。从预测器获取类别名称:使用预测器的模型获取类别名称并返回。使用场景。获取类别名称:在模型训练、预测或评估过程中,获取类别名称以便正确解释模型输出。支持多种模型格式:即使模型对象本身没有 names 属性,也可以通过预测器获取类别名称,确保方法的通用性。注意事项。如果模型对象已经包含 names 属性,则直接返回这些名称,避免不必要的初始化。如果需要预测器来获取类别名称,确保在调用 names 属性之前,模型已经准备好(例如,调用过 predict() 方法)。
# 这段代码定义了 Model 类的一个 @property 方法 device ,用于获取模型当前所在的设备(CPU 或 GPU)。
@property
# 定义了一个名为 device 的 @property 方法,属于 Model 类。
# 返回一个 torch.device 对象,表示模型当前所在的设备。
def device(self) -> torch.device:
# 检索分配模型参数的设备。
# 此属性确定当前存储模型参数的设备(CPU 或 GPU)。它仅适用于 nn.Module 实例的模型。
# 引发:
# AttributeError:如果模型不是 PyTorch nn.Module 实例。
"""
Retrieves the device on which the model's parameters are allocated.
This property determines the device (CPU or GPU) where the model's parameters are currently stored. It is
applicable only to models that are instances of nn.Module.
Returns:
(torch.device): The device (CPU/GPU) of the model.
Raises:
AttributeError: If the model is not a PyTorch nn.Module instance.
Examples:
>>> model = YOLO("yolo11n.pt")
>>> print(model.device)
device(type='cuda', index=0) # if CUDA is available
>>> model = model.to("cpu")
>>> print(model.device)
device(type='cpu')
"""
# 检查模型是否为 nn.Module 。
# 使用 isinstance(self.model, nn.Module) 检查 self.model 是否为 torch.nn.Module 的实例。 如果不是,返回 None ,表示模型当前没有被分配到任何设备。
# 获取模型参数的设备。
# 如果 self.model 是 nn.Module 的实例,调用 self.model.parameters() 获取模型的参数生成器。 使用 next(self.model.parameters()) 获取第一个参数(通常是模型的第一个张量)。 调用 .device 属性,获取该参数所在的设备(CPU 或 GPU)。
return next(self.model.parameters()).device if isinstance(self.model, nn.Module) else None
# 这个 @property 方法的核心功能是获取模型当前所在的设备。具体步骤如下。检查模型类型:如果模型是 torch.nn.Module 的实例,继续获取设备信息。如果不是,返回 None ,表示模型当前没有被分配到任何设备。获取第一个参数的设备:通过 self.model.parameters() 获取模型的参数生成器。使用 next() 获取第一个参数。调用 .device 属性,获取该参数所在的设备。使用场景。设备检查:在模型训练、预测或评估过程中,检查模型是否在正确的设备上(CPU 或 GPU)。动态设备管理:在代码中动态获取模型的设备信息,以便在需要时将模型移动到其他设备。注意事项。如果模型没有参数(例如,模型为空或尚未初始化), next(self.model.parameters()) 会抛出 StopIteration 异常。在这种情况下,建议在调用前检查模型是否包含参数。如果模型未分配到任何设备,返回值为 None 。
# 这段代码定义了 Model 类的一个 @property 方法 transforms ,用于获取模型的输入数据预处理变换(transforms)。
@property
# 定义了一个名为 transforms 的 @property 方法,属于 Model 类。
def transforms(self):
# 检索应用于已加载模型的输入数据的转换。
# 如果转换在模型中定义,则此属性返回转换。转换通常包括预处理步骤,如调整大小、规范化和数据增强,这些步骤在输入数据输入模型之前应用于输入数据。
"""
Retrieves the transformations applied to the input data of the loaded model.
This property returns the transformations if they are defined in the model. The transforms
typically include preprocessing steps like resizing, normalization, and data augmentation
that are applied to input data before it is fed into the model.
Returns:
(object | None): The transform object of the model if available, otherwise None.
Examples:
>>> model = YOLO("yolo11n.pt")
>>> transforms = model.transforms
>>> if transforms:
... print(f"Model transforms: {transforms}")
... else:
... print("No transforms defined for this model.")
"""
# 检查模型是否具有 transforms 属性。
# 使用 hasattr(self.model, "transforms") 检查 self.model 是否有一个名为 transforms 的属性。
# 如果存在,返回 self.model.transforms ,即模型的输入数据预处理变换。
# 如果不存在,返回 None ,表示模型没有定义输入数据预处理变换。
return self.model.transforms if hasattr(self.model, "transforms") else None
# 这个 @property 方法的核心功能是获取模型的输入数据预处理变换。具体步骤如下。检查模型是否具有 transforms 属性:如果模型定义了 transforms ,返回该属性。如果模型没有定义 transforms ,返回 None 。使用场景。数据预处理:在模型训练、预测或评估过程中,获取模型的输入数据预处理变换,以便正确处理输入数据。模型调试:检查模型是否定义了预处理变换,确保数据输入的一致性。注意事项。如果模型没有定义 transforms ,返回值为 None 。在这种情况下,可能需要手动定义预处理步骤。如果模型定义了 transforms ,但其内容为空或无效,需要进一步检查模型的实现细节。
# 这段代码定义了 Model 类的 add_callback 方法,用于向模型添加自定义回调函数。回调函数通常用于在模型训练、预测或其他操作中执行特定的任务,例如日志记录、性能监控或自定义行为。
# 定义了一个名为 add_callback 的方法,属于 Model 类。 参数 :
# 1.event : str ,回调事件的名称,例如 "on_train_start" 、 "on_epoch_end" 等。
# 2.func :回调函数,将在指定事件发生时被调用。
# 返回值 None ,因为该方法的主要目的是修改内部状态,而不是返回值。
def add_callback(self, event: str, func) -> None:
# 为指定事件添加回调函数。
# 此方法允许注册在模型操作(例如训练或推理)期间触发特定事件的自定义回调函数。回调提供了一种在模型生命周期的各个阶段扩展和自定义模型行为的方法。
# 引发:
# ValueError:如果事件名称无法识别或无效。
"""
Adds a callback function for a specified event.
This method allows registering custom callback functions that are triggered on specific events during
model operations such as training or inference. Callbacks provide a way to extend and customize the
behavior of the model at various stages of its lifecycle.
Args:
event (str): The name of the event to attach the callback to. Must be a valid event name recognized
by the Ultralytics framework.
func (Callable): The callback function to be registered. This function will be called when the
specified event occurs.
Raises:
ValueError: If the event name is not recognized or is invalid.
Examples:
>>> def on_train_start(trainer):
... print("Training is starting!")
>>> model = YOLO("yolo11n.pt")
>>> model.add_callback("on_train_start", on_train_start)
>>> model.train(data="coco8.yaml", epochs=1)
"""
# 添加回调函数。
# self.callbacks 是一个字典,键为事件名称(如 "on_train_start" ),值为回调函数列表。
# 使用 event 作为键,将回调函数 func 添加到对应的列表中。
# 如果 event 不存在于 self.callbacks 中,Python 会自动创建一个空列表,并将 func 添加到该列表中。
self.callbacks[event].append(func)
# 这个方法的核心功能是向模型添加自定义回调函数。具体步骤如下。检查事件名称:使用 event 作为键,查找 self.callbacks 中对应的回调函数列表。添加回调函数:将回调函数 func 添加到对应的列表中。如果该事件的回调列表不存在,自动创建一个空列表并添加回调函数。使用场景。自定义行为:在模型训练、预测或其他操作中执行特定任务,例如:在训练开始时记录日志。在每个 epoch 结束时保存模型权重。在预测时执行自定义的后处理操作。扩展模型功能:通过添加回调函数,无需修改模型代码即可扩展其功能。注意事项。事件名称:确保使用的事件名称与模型支持的事件一致。例如,常见的事件包括 "on_train_start" 、 "on_epoch_end" 、 "on_predict" 等。回调函数:回调函数应接受适当的参数,具体取决于事件类型。例如,训练事件的回调函数可能需要接受训练器对象或当前 epoch 作为参数。线程安全:如果模型在多线程环境中运行,确保回调函数的线程安全性。
# 这段代码定义了 Model 类的 clear_callback 方法,用于清除指定事件的所有回调函数。
# 定义了一个名为 clear_callback 的方法,属于 Model 类。 参数 :
# 1.event : str ,要清除回调函数的事件名称,例如 "on_train_start" 、 "on_epoch_end" 等。
# 返回值 None ,因为该方法的主要目的是修改内部状态,而不是返回值。
def clear_callback(self, event: str) -> None:
# 清除为指定事件注册的所有回调函数。
# 此方法删除与给定事件关联的所有自定义和默认回调函数。它将指定事件的回调列表重置为空列表,从而有效地删除该事件的所有已注册回调。
# 注意事项:
# - 此方法会影响用户添加的自定义回调和 Ultralytics 框架提供的默认回调。
# - 调用此方法后,将不会为指定事件执行任何回调,直到添加新的回调。
# - 请谨慎使用,因为它会删除所有回调,包括某些操作正常运行可能需要的基本回调。
"""
Clears all callback functions registered for a specified event.
This method removes all custom and default callback functions associated with the given event.
It resets the callback list for the specified event to an empty list, effectively removing all
registered callbacks for that event.
Args:
event (str): The name of the event for which to clear the callbacks. This should be a valid event name
recognized by the Ultralytics callback system.
Examples:
>>> model = YOLO("yolo11n.pt")
>>> model.add_callback("on_train_start", lambda: print("Training started"))
>>> model.clear_callback("on_train_start")
>>> # All callbacks for 'on_train_start' are now removed
Notes:
- This method affects both custom callbacks added by the user and default callbacks
provided by the Ultralytics framework.
- After calling this method, no callbacks will be executed for the specified event
until new ones are added.
- Use with caution as it removes all callbacks, including essential ones that might
be required for proper functioning of certain operations.
"""
# 清除回调函数。
# 使用 event 作为键,直接将 self.callbacks 中对应的回调函数列表设置为空列表 [] 。
# 这一步会移除该事件的所有回调函数,确保在后续操作中不会触发任何回调。
self.callbacks[event] = []
# 这个方法的核心功能是清除指定事件的所有回调函数。具体步骤如下。定位事件:使用 event 作为键,查找 self.callbacks 中对应的回调函数列表。清空回调列表:将该事件的回调函数列表设置为空列表 [] ,移除所有回调函数。使用场景。重置回调:在需要重置某个事件的回调函数时,使用此方法清除所有已注册的回调。调试和测试:在调试或测试过程中,清除回调函数以避免不必要的干扰。动态管理回调:在模型运行过程中,根据需要动态添加或清除回调函数。注意事项。事件名称:确保提供的事件名称是有效的,且存在于 self.callbacks 中。如果事件名称不存在,Python 会自动创建一个空列表,但可能不是预期的行为。影响范围:清除回调函数后,该事件将不再触发任何回调,直到重新添加回调函数。
# 这段代码定义了 Model 类的 reset_callbacks 方法,用于将所有回调函数重置为默认值。
# 定义了一个名为 reset_callbacks 的方法,属于 Model 类。
# 返回值 None ,因为该方法的主要目的是修改内部状态,而不是返回值。
def reset_callbacks(self) -> None:
# 将所有回调重置为其默认函数。
# 此方法恢复所有事件的默认回调函数,删除之前添加的任何自定义回调。它遍历所有默认回调事件,并用默认回调替换当前回调。
# 默认回调在“callbacks.default_callbacks”字典中定义,其中包含模型生命周期中各种事件的预定义函数,例如 on_train_start、on_epoch_end 等。
# 当您想要在进行自定义修改后恢复到原始回调集时,此方法很有用,可确保在不同运行或实验中的行为一致。
"""
Resets all callbacks to their default functions.
This method reinstates the default callback functions for all events, removing any custom callbacks that were
previously added. It iterates through all default callback events and replaces the current callbacks with the
default ones.
The default callbacks are defined in the 'callbacks.default_callbacks' dictionary, which contains predefined
functions for various events in the model's lifecycle, such as on_train_start, on_epoch_end, etc.
This method is useful when you want to revert to the original set of callbacks after making custom
modifications, ensuring consistent behavior across different runs or experiments.
Examples:
>>> model = YOLO("yolo11n.pt")
>>> model.add_callback("on_train_start", custom_function)
>>> model.reset_callbacks()
# All callbacks are now reset to their default functions
"""
# 遍历 callbacks.default_callbacks 的所有键(即所有支持的事件名称)。 callbacks.default_callbacks 是一个字典,存储了每个事件的默认回调函数。
for event in callbacks.default_callbacks.keys():
# 对于每个事件 event 。
# 从 callbacks.default_callbacks 中获取该事件的默认回调函数列表。
# 将 self.callbacks[event] 设置为该事件的默认回调函数列表的第一个元素(通常是一个默认的回调函数)。
# 这一步确保每个事件的回调函数被重置为默认值。
self.callbacks[event] = [callbacks.default_callbacks[event][0]]
# 这个方法的核心功能是将所有回调函数重置为默认值。具体步骤如下。遍历所有支持的事件:使用 callbacks.default_callbacks.keys() 获取所有支持的事件名称。重置回调函数:对于每个事件,将 self.callbacks[event] 设置为该事件的默认回调函数列表的第一个元素。使用场景。重置回调:在模型训练、预测或其他操作之前,将所有回调函数重置为默认值,以确保一致的行为。动态管理回调:在模型运行过程中,根据需要动态重置回调函数。调试和测试:在调试或测试过程中,重置回调函数以避免不必要的干扰。注意事项。默认回调函数: callbacks.default_callbacks 应该是一个字典,其中每个键对应一个事件名称,值是一个包含默认回调函数的列表。影响范围:重置回调函数后,所有事件的回调将被设置为默认值,直到重新添加自定义回调函数。
# 这段代码定义了 Model 类的一个 @staticmethod 方法 _reset_ckpt_args ,用于过滤和保留加载 PyTorch 模型时需要的关键参数。
@staticmethod
# 定义了一个名为 _reset_ckpt_args 的静态方法,属于 Model 类。 参数 :
# 1.args : dict ,输入的参数字典,通常是从模型检查点(checkpoint)加载的参数。
# 返回一个过滤后的字典,仅包含需要保留的关键参数。
def _reset_ckpt_args(args: dict) -> dict:
# 加载 PyTorch 模型检查点时重置特定参数。
# 此静态方法过滤输入参数字典以仅保留一组对模型加载很重要的特定键。它用于确保从检查点加载模型时仅保留相关参数,丢弃任何不必要或可能冲突的设置。
"""
Resets specific arguments when loading a PyTorch model checkpoint.
This static method filters the input arguments dictionary to retain only a specific set of keys that are
considered important for model loading. It's used to ensure that only relevant arguments are preserved
when loading a model from a checkpoint, discarding any unnecessary or potentially conflicting settings.
Args:
args (dict): A dictionary containing various model arguments and settings.
Returns:
(dict): A new dictionary containing only the specified include keys from the input arguments.
Examples:
>>> original_args = {"imgsz": 640, "data": "coco.yaml", "task": "detect", "batch": 16, "epochs": 100}
>>> reset_args = Model._reset_ckpt_args(original_args)
>>> print(reset_args)
{'imgsz': 640, 'data': 'coco.yaml', 'task': 'detect'}
"""
# 定义一个集合 include ,包含 需要保留的关键参数名称 。
# "imgsz" :输入图像尺寸。
# "data" :数据集配置路径。
# "task" :任务类型(如检测、分割等)。
# "single_cls" :是否为单类别任务。
include = {"imgsz", "data", "task", "single_cls"} # only remember these arguments when loading a PyTorch model
# 使用字典推导式过滤输入字典 args 。
# 遍历 args 的键值对 (k, v) 。
# 仅保留键 k 在 include 集合中的键值对。
# 返回 过滤后的字典 ,仅包含需要保留的关键参数。
return {k: v for k, v in args.items() if k in include}
# 这个静态方法的核心功能是过滤检查点中的参数,只保留特定的关键参数。具体步骤如下。定义需要保留的参数键:使用集合 include 定义需要保留的参数键。过滤参数:使用字典推导式,只保留键在 include 集合中的参数键值对。返回过滤后的参数字典。使用场景。加载模型时过滤参数:在加载 PyTorch 模型时,只保留关键参数,忽略其他可能不必要的参数。简化模型配置:确保模型加载时只使用必要的参数,避免潜在的冲突或错误。注意事项。参数选择: include 集合中的参数是根据实际需求选择的,可以根据具体场景调整。返回值:返回的是一个过滤后的参数字典,只包含需要保留的参数键值对。
# def __getattr__(self, attr):
# """Raises error if object has no requested attribute."""
# name = self.__class__.__name__
# raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
# 这段代码定义了 Model 类的 _smart_load 方法,用于动态加载与当前任务相关的模块或类。
# 定义了一个名为 _smart_load 的方法,属于 Model 类。 参数 :
# 1.key : str ,要加载的模块或类的键(例如 "model" 、 "trainer" 、 "predictor" 等)。
# 返回与当前任务相关的模块或类。
def _smart_load(self, key: str):
# 根据模型任务加载适当的模块。
# 此方法根据模型的当前任务和提供的键动态选择并返回正确的模块(模型、训练器、验证器或预测器)。它使用 task_map 属性来确定要加载的正确模块。
# 引发:
# NotImplementedError:如果当前任务不支持指定的键。
# 注意事项:
# - 此方法通常由 Model 类的其他方法内部使用。
# - task_map 属性应正确使用每个任务的正确映射进行初始化。
"""
Loads the appropriate module based on the model task.
This method dynamically selects and returns the correct module (model, trainer, validator, or predictor)
based on the current task of the model and the provided key. It uses the task_map attribute to determine
the correct module to load.
Args:
key (str): The type of module to load. Must be one of 'model', 'trainer', 'validator', or 'predictor'.
Returns:
(object): The loaded module corresponding to the specified key and current task.
Raises:
NotImplementedError: If the specified key is not supported for the current task.
Examples:
>>> model = Model(task="detect")
>>> predictor = model._smart_load("predictor")
>>> trainer = model._smart_load("trainer")
Notes:
- This method is typically used internally by other methods of the Model class.
- The task_map attribute should be properly initialized with the correct mappings for each task.
"""
# 尝试动态加载模块。
try:
# self.task_map 是一个字典,存储了不同任务(如 "detect" 、 "segment" 等)对应的模块或类。
# self.task 是当前模型的任务类型。
# 使用 self.task_map[self.task][key] 尝试获取与当前任务和键 key 对应的模块或类。
# 如果成功, 返回该模块或类 。
return self.task_map[self.task][key]
# 捕获异常。
# 如果在尝试加载模块时发生任何异常(例如, self.task 或 key 不存在于 self.task_map 中),捕获该异常并存储到变量 e 中。
except Exception as e:
# 获取当前类的名称。 使用 self.__class__.__name__ 获取当前类的名称(例如 "Model" )。
name = self.__class__.__name__
# 获取 调用该方法的函数名称 。 使用 inspect.stack() 获取当前调用栈信息。 inspect.stack()[1][3] 获取调用该方法的函数名称(例如 "train" 、 "predict" 等)。 这一步用于在错误消息中提供上下文信息,帮助用户理解问题发生的位置。
mode = inspect.stack()[1][3] # get the function name.
# 抛出自定义错误。
raise NotImplementedError(
# 使用 NotImplementedError 抛出一个自定义的错误消息,提示用户当前任务不支持该模式。 错误消息中包含了当前类的名称、调用的函数名称和任务类型,以便用户快速定位问题。
emojis(f"WARNING ⚠️ '{name}' model does not support '{mode}' mode for '{self.task}' task yet.") # 警告⚠️'{name}' 模型尚不支持'{self.task}' 任务的'{mode}' 模式。
# 使用 from e 将原始异常作为上下文信息附加到自定义错误中,便于调试。
) from e
# 这个方法的核心功能是动态加载与当前任务相关的模块或类。具体步骤如下。尝试加载模块:使用 self.task_map[self.task][key] 尝试获取与当前任务和键 key 对应的模块或类。捕获异常:如果加载失败,捕获异常并获取当前类的名称和调用的函数名称。抛出自定义错误:抛出 NotImplementedError ,提示用户当前任务不支持该模式,并提供详细的上下文信息。使用场景。动态加载模块:在模型训练、预测或其他操作中,根据当前任务动态加载所需的模块或类。错误处理:在加载失败时,提供详细的错误信息,帮助用户快速定位问题。注意事项。任务映射: self.task_map 应该是一个字典,存储了不同任务对应的模块或类。如果任务映射未正确初始化,可能会导致加载失败。上下文信息:通过 inspect.stack() 获取调用函数的名称,提供更详细的错误信息。错误类型:使用 NotImplementedError 表示某些功能尚未实现或不支持,便于用户理解和处理。
# 这段代码定义了 Model 类的一个 @property 方法 task_map ,用于获取模型的任务映射(task map)。任务映射是一个字典,它将不同的任务类型(如 "detect" 、 "segment" 等)映射到相应的模块或类。
@property
# 定义了一个名为 task_map 的 @property 方法,属于 Model 类。
# 返回一个字典,表示任务映射。
def task_map(self) -> dict:
# 提供从模型任务到不同模式的相应类的映射。
# 此属性方法返回一个字典,该字典将每个支持的任务(例如,检测、分段、分类)映射到嵌套字典。嵌套字典包含不同操作模式(模型、训练器、验证器、预测器)到其各自类实现的映射。
# 映射允许根据模型的任务和所需的操作模式动态加载适当的类。这有助于在 Ultralytics 框架内处理各种任务和模式的灵活且可扩展的架构。
# 注意:
# 此方法的实际实现可能因 Ultralytics 框架支持的特定任务和类而异。文档字符串提供了预期行为和结构的一般描述。
"""
Provides a mapping from model tasks to corresponding classes for different modes.
This property method returns a dictionary that maps each supported task (e.g., detect, segment, classify)
to a nested dictionary. The nested dictionary contains mappings for different operational modes
(model, trainer, validator, predictor) to their respective class implementations.
The mapping allows for dynamic loading of appropriate classes based on the model's task and the
desired operational mode. This facilitates a flexible and extensible architecture for handling
various tasks and modes within the Ultralytics framework.
Returns:
(Dict[str, Dict[str, Any]]): A dictionary where keys are task names (str) and values are
nested dictionaries. Each nested dictionary has keys 'model', 'trainer', 'validator', and
'predictor', mapping to their respective class implementations.
Examples:
>>> model = Model()
>>> task_map = model.task_map
>>> detect_class_map = task_map["detect"]
>>> segment_class_map = task_map["segment"]
Note:
The actual implementation of this method may vary depending on the specific tasks and
classes supported by the Ultralytics framework. The docstring provides a general
description of the expected behavior and structure.
"""
# 抛出未实现错误。
# 当尝试访问 task_map 属性时,抛出一个 NotImplementedError 异常。
# 错误消息提示用户需要为模型提供任务映射。
# 这种设计模式通常用于抽象类或基类,要求子类必须实现该属性。
raise NotImplementedError("Please provide task map for your model!") # 请提供您的模型的任务图!
# 这个 @property 方法的核心功能是提供模型的任务映射。具体步骤如下。抛出未实现错误:当访问 task_map 属性时,抛出 NotImplementedError ,提示用户需要实现该属性。这种设计模式确保子类必须提供任务映射,以支持动态加载模块或类。使用场景。抽象类或基类:在定义抽象类或基类时,使用这种模式强制子类实现特定的属性或方法。任务支持:确保模型支持多种任务(如检测、分割等),并通过任务映射动态加载相应的模块或类。注意事项。子类实现:子类必须实现 task_map 属性,否则会抛出 NotImplementedError 。任务映射格式:任务映射通常是一个字典,键为任务类型(如 "detect" ),值为对应的模块或类。
# 这段代码定义了 Model 类的 eval 方法,用于将模型设置为评估模式。
# 定义了一个名为 eval 的方法,属于 Model 类。 该方法没有参数,也没有返回值,其主要目的是将模型设置为评估模式。
def eval(self):
# 将模型设置为评估模式。
# 此方法将模型的模式更改为评估,这会影响在训练和评估期间表现不同的层(如 dropout 和 batch normalization)。
"""
Sets the model to evaluation mode.
This method changes the model's mode to evaluation, which affects layers like dropout and batch normalization
that behave differently during training and evaluation.
Returns:
(Model): The model instance with evaluation mode set.
Examples:
>> model = YOLO("yolo11n.pt")
>> model.eval()
"""
# 调用底层模型对象( self.model )的 eval 方法,将模型设置为 评估模式 。
# 在 PyTorch 中, eval 方法用于将模型的某些层(如 Dropout 和 BatchNorm )切换到评估模式,以确保这些层在推理时的行为与训练时不同。
# Dropout 层在评估模式下不会随机丢弃神经元。
# BatchNorm 层在评估模式下会使用全局统计量,而不是小批量统计量。
self.model.eval()
# 返回当前模型实例( self ),允许链式调用。 这使得用户可以在调用 eval 方法后继续对模型进行操作,例如 : model.eval().predict(source="image.jpg") 。
return self
# 这个方法的核心功能是将模型设置为评估模式。具体步骤如下。调用底层模型的 eval 方法:将模型的某些层切换到评估模式,以确保推理时的行为正确。返回模型实例:返回当前模型实例,允许链式调用。使用场景。模型评估:在模型训练完成后,使用 eval 方法将模型切换到评估模式,以进行推理或评估。推理优化:确保模型在推理时的行为与训练时不同,以提高推理性能。注意事项。在调用 eval 方法后,模型将保持在评估模式,直到显式调用 train 方法切换回训练模式。如果模型中包含自定义层或模块,确保这些层也支持 eval 和 train 方法。
# 这段代码定义了 Model 类的 __getattr__ 方法,用于处理对实例属性的访问。 __getattr__ 是一个特殊方法,当尝试访问一个实例的属性时,如果该属性不存在于实例的字典中,Python 会调用这个方法。
# 定义了一个名为 __getattr__ 的方法,属于 Model 类。 参数 :
# 1.name :尝试访问的属性名称。
def __getattr__(self, name):
# 允许直接通过 Model 类访问模型属性。
# 此方法提供了一种通过 Model 类实例直接访问底层模型属性的方法。它首先检查所请求的属性是否为“model”,如果是,则从模块字典中返回模型。否则,它将属性查找委托给底层模型。
# 引发:
# AttributeError:如果所请求的属性在模型中不存在。
"""
Enables accessing model attributes directly through the Model class.
This method provides a way to access attributes of the underlying model directly through the Model class
instance. It first checks if the requested attribute is 'model', in which case it returns the model from
the module dictionary. Otherwise, it delegates the attribute lookup to the underlying model.
Args:
name (str): The name of the attribute to retrieve.
Returns:
(Any): The requested attribute value.
Raises:
AttributeError: If the requested attribute does not exist in the model.
Examples:
>>> model = YOLO("yolo11n.pt")
>>> print(model.stride)
>>> print(model.task)
"""
# 处理属性访问。
# 如果尝试访问的属性名称是 "model" ,则返回 self._modules["model"] 。
# 这里假设 self._modules 是一个字典,存储了模型的模块或子模块。
# 这种设计可能用于支持更复杂的模块化结构,例如在模型中嵌套多个子模型。
# 如果尝试访问的属性名称不是 "model" ,则调用 getattr(self.model, name) 。
# getattr 是 Python 的内置函数,用于动态访问对象的属性。
# 这一步会尝试从底层模型对象( self.model )中获取指定的属性。
# 如果底层模型对象中存在该属性,则返回其值;否则, getattr 会抛出 AttributeError 。
return self._modules["model"] if name == "model" else getattr(self.model, name)
# 这个方法的核心功能是处理对实例属性的访问,特别是当属性不存在于实例的字典中时。具体步骤如下。检查属性名称:如果属性名称是 "model" ,返回 self._modules["model"] 。动态访问底层模型的属性:如果属性名称不是 "model" ,使用 getattr 从底层模型对象( self.model )中获取属性。使用场景。属性代理:允许用户通过 Model 类的实例直接访问底层模型对象的属性,而无需显式访问底层模型对象。模块化设计:支持更复杂的模块化结构,例如在模型中嵌套多个子模型。注意事项。如果底层模型对象中不存在指定的属性, getattr 会抛出 AttributeError 。这种设计假设 self.model 是一个有效的对象,并且支持动态属性访问。如果需要更复杂的属性访问逻辑,可以在 __getattr__ 方法中添加额外的检查或处理逻辑。
# Model 类是一个功能丰富的框架,用于封装和管理深度学习模型的训练、预测、导出和评估等操作。它通过灵活的接口设计,支持多种任务类型(如目标检测、分割、分类等),并允许用户根据需求动态加载和切换模型组件。该类提供了从模型初始化、训练、超参数调优到性能评估和部署的全流程支持,同时通过回调机制和任务映射增强了可扩展性和自定义能力。此外, Model 类还集成了对 Ultralytics HUB 和 Triton Server 的支持,使得模型管理更加便捷。通过精心设计的属性和方法, Model 类确保了模型在不同阶段的高效管理和灵活操作,适用于从研究到生产部署的各种场景。