大家好:
YOLOv8项目的模型文件存在于如下路径:
./ultralytics-main/ultralytics/cfg/models
本文以目标检测模型为例,梳理YOLOv8在执行model = YOLO('yolov8n.yaml')时是如何加载模型的。
根据YOLOv8项目的官方示例,可以用如下代码开展目标检测模型的训练:
from ultralytics import YOLO
# The next three assignments are alternatives: each one rebinds `model`, so only the last takes effect if all are run.
model = YOLO('yolov8n.yaml')  # a .yaml name is routed to Model._new(), i.e. a new model built from the config
model = YOLO('yolov8n.pt')  # any other suffix is routed to Model._load()
model = YOLO('yolov8n.yaml').load('yolov8n.pt')  # build from YAML, then call .load() with the .pt file (presumably transfers pretrained weights - confirm in the Ultralytics docs)
results = model.train(data='coco128.yaml', epochs=100, imgsz=640)  # train with the coco128.yaml data config, epochs=100, imgsz=640
YOLO类存在于./ultralytics-main/ultralytics/models/yolo/model.py:
class YOLO(Model):
    """YOLO (You Only Look Once) object detection model."""

    def __init__(self, model="yolov8n.pt", task=None, verbose=False):
        """
        Initialize YOLO model, switching to YOLOWorld if model filename contains '-world'.

        Args:
            model (str | Path): Model file name or path, e.g. 'yolov8n.pt' or 'yolov8n.yaml'.
            task (str | None): Task name forwarded to the parent initializer.
            verbose (bool): Verbosity flag forwarded to the parent initializer.
        """
        path = Path(model)
        # For a plain name such as 'yolov8n.yaml' the stem contains no '-world',
        # so the if-branch is skipped and the else-branch runs.
        if "-world" in path.stem and path.suffix in {".pt", ".yaml", ".yml"}:  # if YOLOWorld PyTorch model
            # Turn this object into the YOLOWorld instance by adopting its class and attribute dict.
            new_instance = YOLOWorld(path)
            self.__class__ = type(new_instance)
            self.__dict__ = new_instance.__dict__
        else:
            # Continue with default YOLO initialization
            super().__init__(model=model, task=task, verbose=verbose)

    @property
    def task_map(self):
        """
        Map head to model, trainer, validator, and predictor classes.

        Returns:
            (dict): One entry per task name ('classify', 'detect', 'segment', 'pose', 'obb'), each a dict
                with the keys 'model', 'trainer', 'validator' and 'predictor'.
        """
        return {
            "classify": {
                "model": ClassificationModel,
                "trainer": yolo.classify.ClassificationTrainer,
                "validator": yolo.classify.ClassificationValidator,
                "predictor": yolo.classify.ClassificationPredictor,
            },
            "detect": {
                "model": DetectionModel,
                "trainer": yolo.detect.DetectionTrainer,
                "validator": yolo.detect.DetectionValidator,
                "predictor": yolo.detect.DetectionPredictor,
            },
            "segment": {
                "model": SegmentationModel,
                "trainer": yolo.segment.SegmentationTrainer,
                "validator": yolo.segment.SegmentationValidator,
                "predictor": yolo.segment.SegmentationPredictor,
            },
            "pose": {
                "model": PoseModel,
                "trainer": yolo.pose.PoseTrainer,
                "validator": yolo.pose.PoseValidator,
                "predictor": yolo.pose.PosePredictor,
            },
            "obb": {
                "model": OBBModel,
                "trainer": yolo.obb.OBBTrainer,
                "validator": yolo.obb.OBBValidator,
                "predictor": yolo.obb.OBBPredictor,
            },
        }
在YOLO类的__init__()函数中,首先通过if条件判断应当执行哪个程序段。为此,先查看path.stem与path.suffix的取值:
from pathlib import Path
# Inspect the two Path attributes that YOLO.__init__ tests before choosing a branch.
model = 'yolov8n.yaml'
path = Path(model)
# Label each value so the printed text matches the output recorded in the string below.
print(f'path.stem:{path.stem}')
print(f'path.suffix:{path.suffix}')
'''
path.stem:yolov8n
path.suffix:.yaml
'''
path.stem与path.suffix运行结果如上:虽然后缀.yaml在允许的集合内,但path.stem('yolov8n')中不包含"-world",因此不会执行第一个if程序段,转而执行else程序段,即YOLO父类Model类的__init__()方法,Model类初始化方法如下所示:
class Model(nn.Module):
    """Base class whose initializer decides how the given model string is turned into a model object."""

    def __init__(
        self,
        model: Union[str, Path] = "yolov8n.pt",
        task: str = None,
        verbose: bool = False,
    ) -> None:
        """
        Set up default attributes, then dispatch on the kind of `model` string.

        The string is checked in this order: HUB model, Triton Server URL, then a local name whose
        suffix selects `_new()` ('.yaml'/'.yml') or `_load()` (anything else).

        Args:
            model (str | Path): Model name, path or URL.
            task (str | None): Task type stored on the instance and passed to `_new()`/`_load()`.
            verbose (bool): Passed to `_new()` when building from a YAML file.
        """
        super().__init__()
        self.callbacks = callbacks.get_default_callbacks()
        self.predictor = None  # reuse predictor
        self.model = None  # model object
        self.trainer = None  # trainer object
        self.ckpt = None  # if loaded from *.pt
        self.cfg = None  # if loaded from *.yaml
        self.ckpt_path = None
        self.overrides = {}  # overrides for trainer object
        self.metrics = None  # validation/training metrics
        self.session = None  # HUB session
        self.task = task  # task type
        self.model_name = model = str(model).strip()  # strip spaces

        # Check if Ultralytics HUB model from https://hub.ultralytics.com
        if self.is_hub_model(model):
            # Fetch model from HUB
            checks.check_requirements("hub-sdk>0.0.2")
            self.session = self._get_hub_session(model)
            model = self.session.model_file

        # Check if Triton Server model
        elif self.is_triton_model(model):
            # The URL string itself is kept as the model; nothing is loaded locally.
            self.model = model
            self.task = task
            return

        # Load or create new YOLO model
        model = checks.check_model_file_from_stem(model)  # add suffix, i.e. yolov8n -> yolov8n.pt
        if Path(model).suffix in (".yaml", ".yml"):
            self._new(model, task=task, verbose=verbose)
        else:
            self._load(model, task=task)

        # Record the resolved name (it may have gained a suffix or come from the HUB session).
        self.model_name = model
此类的model参数仍然等于传过来的yolov8n.yaml,在Model类的初始化方法中,首先执行:
if self.is_hub_model(model):
is_hub_model(model)函数如下:
def is_hub_model(model: str) -> bool:
    """
    Check if the provided model is a HUB model.

    Args:
        model (str): Model string to inspect.

    Returns:
        (bool): True if the string matches any of the three recognised forms below.
    """
    return any(
        (
            model.startswith(f"{HUB_WEB_ROOT}/models/"),  # i.e. https://hub.ultralytics.com/models/MODEL_ID
            [len(x) for x in model.split("_")] == [42, 20],  # APIKEY_MODELID
            # A bare 20-character id: must not be an existing path and must contain no '.', '/' or '\'.
            len(model) == 20 and not Path(model).exists() and all(x not in model for x in "./\\"),  # MODELID
        )
    )
HUB_WEB_ROOT参数如下:
# Root URL used when recognising HUB model links; read from the ULTRALYTICS_HUB_WEB environment variable, defaulting to https://hub.ultralytics.com.
HUB_WEB_ROOT = os.environ.get("ULTRALYTICS_HUB_WEB", "https://hub.ultralytics.com")
对于'yolov8n.yaml':它不以HUB_WEB_ROOT开头;按"_"切分后只有一段,长度列表为[12],不等于[42, 20];字符串长度为12而不是20。三个条件均不成立,因此该函数返回False。再检查elif self.is_triton_model(model),is_triton_model(model)函数如下:
def is_triton_model(model: str) -> bool:
    """
    Is model a Triton Server URL string, i.e. <scheme>://<netloc>/<endpoint>/<task_name>

    Args:
        model (str): Model string to inspect.

    Returns:
        (bool): True only when the string has a network location, a non-empty path and an
            'http' or 'grpc' scheme.
    """
    from urllib.parse import urlsplit

    url = urlsplit(model)
    # bool() keeps the declared return type: the bare `and` chain would hand back an
    # empty string (not False) whenever netloc or path is empty.
    return bool(url.netloc and url.path and url.scheme in {"http", "grpc"})
执行如下语句:
from urllib.parse import urlsplit
# Show how urlsplit() decomposes a bare file name such as 'yolov8n.yaml'.
path = 'yolov8n.yaml'
url = urlsplit(path)
print(f'url:{url}')
# The string literal below records the printed result: scheme and netloc are empty strings and the whole input lands in `path`.
'''
url:SplitResult(scheme='', netloc='', path='yolov8n.yaml', query='', fragment='')
'''
因此,url.netloc与url.scheme均为空字符串''(而不是None),and表达式在url.netloc处即短路,结果为假值,elif分支不会执行。
接着执行model = checks.check_model_file_from_stem(model),checks.check_model_file_from_stem(model)函数如下:
def check_model_file_from_stem(model="yolov8n"):
    """
    Return a model filename from a valid model stem.

    Args:
        model (str | Path): Model name, with or without a suffix.

    Returns:
        (str | Path): `model` with '.pt' appended (as a Path) when it has no suffix and its stem is
            listed in `downloads.GITHUB_ASSETS_STEMS`; otherwise `model` unchanged.
    """
    # Names that are empty or already carry a suffix short-circuit here and are returned as-is.
    if model and not Path(model).suffix and Path(model).stem in downloads.GITHUB_ASSETS_STEMS:
        return Path(model).with_suffix(".pt")  # add suffix, i.e. yolov8n -> yolov8n.pt
    else:
        return model
由于Path(model).suffix不为空,因此直接执行return model,注意此时的参数model仍然为'yolov8n.yaml'。
接着判断:
if Path(model).suffix in (".yaml", ".yml"):
显然上述判断为真,则执行下述语句:
self._new(model, task=task, verbose=verbose)
该函数如下:
def _new(self, cfg: str, task=None, model=None, verbose=False) -> None:
    """
    Initializes a new model and infers the task type from the model definitions.

    Args:
        cfg (str): model configuration file
        task (str | None): model task
        model (BaseModel): Customized model.
        verbose (bool): display model info on load
    """
    cfg_dict = yaml_model_load(cfg)
    self.cfg = cfg
    # An explicit task wins; otherwise it is guessed from the loaded config dict.
    self.task = task or guess_model_task(cfg_dict)
    # Use the caller-supplied model class if given, else the class picked by _smart_load("model").
    # Verbose output is only requested when RANK == -1.
    self.model = (model or self._smart_load("model"))(cfg_dict, verbose=verbose and RANK == -1)  # build model
    self.overrides["model"] = self.cfg
    self.overrides["task"] = self.task

    # Below added to allow export from YAMLs
    self.model.args = {**DEFAULT_CFG_DICT, **self.overrides}  # combine default and model args (prefer model args)
    self.model.task = self.task
首先执行:
cfg_dict = yaml_model_load(cfg)
yaml_model_load(cfg)函数如下:
def yaml_model_load(path):
    """
    Load a YOLOv8 model from a YAML file.

    Args:
        path (str | Path): YAML file name or path, e.g. 'yolov8n.yaml'.

    Returns:
        (dict): The loaded YAML content plus two added keys: 'scale' (from `guess_model_scale(path)`)
            and 'yaml_file' (the path as a string).
    """
    import re

    path = Path(path)
    # Legacy P6 names (yolov5n6, yolov8x6, ...) are renamed to the '-p6' form with a warning.
    if path.stem in (f"yolov{d}{x}6" for x in "nsmlx" for d in (5, 8)):
        new_stem = re.sub(r"(\d+)([nslmx])6(.+)?$", r"\1\2-p6\3", path.stem)
        LOGGER.warning(f"WARNING ⚠️ Ultralytics YOLO P6 models now use -p6 suffix. Renaming {path.stem} to {new_stem}.")
        path = path.with_name(new_stem + path.suffix)

    # Drop the scale letter so every scale resolves to one shared config file.
    unified_path = re.sub(r"(\d+)([nslmx])(.+)?$", r"\1\3", str(path))  # i.e. yolov8x.yaml -> yolov8.yaml
    # Try the unified name first without raising; fall back to the original name if nothing is found.
    yaml_file = check_yaml(unified_path, hard=False) or check_yaml(path)
    d = yaml_load(yaml_file)  # model dict
    d["scale"] = guess_model_scale(path)
    d["yaml_file"] = str(path)
    return d
第一个if语句判断为False:生成器只会产生yolov5n6、yolov8n6等以6结尾的旧式P6模型名(共10个),而path.stem为'yolov8n',不在其中。只有传入类似'yolov8n6.yaml'的文件名时该判断才为True,此时文件名会被重命名为'yolov8n-p6.yaml'。
因此,会运行re.sub()函数,且结果如下:
unified_path = re.sub(r"(\d+)([nslmx])(.+)?$", r"\1\3", str(path)) # i.e. yolov8x.yaml -> yolov8.yaml
print(f'unified_path:{unified_path}')
'''
unified_path:yolov8.yaml
'''
然后运行如下函数:
yaml_file = check_yaml(unified_path, hard=False) or check_yaml(path)
check_yaml()函数如下:
def check_yaml(file, suffix=(".yaml", ".yml"), hard=True):
    """
    Search/download YAML file (if necessary) and return path, checking suffix.

    Args:
        file (str | Path): YAML file name, path or URL.
        suffix (tuple): Accepted suffixes, forwarded to `check_file()`.
        hard (bool): Forwarded to `check_file()`; when False a failed search does not raise.

    Returns:
        Whatever `check_file()` returns for these arguments.
    """
    return check_file(file, suffix, hard=hard)
check_file() 函数如下:
def check_file(file, suffix="", download=True, hard=True):
    """
    Search/download file (if necessary) and return path.

    Args:
        file (str | Path): File name, path or URL.
        suffix (str | tuple): Suffix(es) passed to `check_suffix()`.
        download (bool): Allow downloading when `file` is an http(s)/rtsp/rtmp/tcp URL.
        hard (bool): Raise when the search finds no match or more than one match.

    Returns:
        (str | list): The resolved file string, or an empty list when the search finds nothing
            and `hard` is False.

    Raises:
        FileNotFoundError: If `hard` is True and the search finds zero or multiple matches.
    """
    check_suffix(file, suffix)  # optional
    file = str(file).strip()  # convert to string and strip spaces
    file = check_yolov5u_filename(file)  # yolov5n -> yolov5nu
    if (
        not file
        or ("://" not in file and Path(file).exists())  # '://' check required in Windows Python<3.10
        or file.lower().startswith("grpc://")
    ):  # file exists or gRPC Triton images
        return file
    elif download and file.lower().startswith(("https://", "http://", "rtsp://", "rtmp://", "tcp://")):  # download
        url = file  # warning: Pathlib turns :// -> :/
        file = url2file(file)  # '%2F' to '/', split https://url.com/file.txt?auth
        if Path(file).exists():
            LOGGER.info(f"Found {clean_url(url)} locally at {file}")  # file already exists
        else:
            downloads.safe_download(url=url, file=file, unzip=False)
        return file
    else:  # search
        # Look for the name anywhere below ROOT / "cfg".
        files = glob.glob(str(ROOT / "cfg" / "**" / file), recursive=True)  # find file
        if not files and hard:
            raise FileNotFoundError(f"'{file}' does not exist")
        elif len(files) > 1 and hard:
            raise FileNotFoundError(f"Multiple files match '{file}', specify exact path: {files}")
        return files[0] if len(files) else []  # return file
check_suffix()函数可以忽略,不影响后续代码执行顺序,因此暂时不处理。注意此时的file参数值为'yolov8.yaml'。
接着执行:
file = check_yolov5u_filename(file) # yolov5n -> yolov5nu
check_yolov5u_filename()函数如下:
def check_yolov5u_filename(file: str, verbose: bool = True):
    """
    Replace legacy YOLOv5 filenames with updated YOLOv5u filenames.

    Args:
        file (str): File name to check.
        verbose (bool): Log a tip when a '.pt' name was rewritten.

    Returns:
        (str): The possibly rewritten file name; names containing neither 'yolov3' nor 'yolov5'
            are returned unchanged.
    """
    if "yolov3" in file or "yolov5" in file:
        if "u.yaml" in file:
            file = file.replace("u.yaml", ".yaml")  # i.e. yolov5nu.yaml -> yolov5n.yaml
        elif ".pt" in file and "u" not in file:
            original_file = file
            file = re.sub(r"(.*yolov5([nsmlx]))\.pt", "\\1u.pt", file)  # i.e. yolov5n.pt -> yolov5nu.pt
            file = re.sub(r"(.*yolov5([nsmlx])6)\.pt", "\\1u.pt", file)  # i.e. yolov5n6.pt -> yolov5n6u.pt
            file = re.sub(r"(.*yolov3(|-tiny|-spp))\.pt", "\\1u.pt", file)  # i.e. yolov3-spp.pt -> yolov3-sppu.pt
            # Only mention the rename when one of the substitutions actually changed the name.
            if file != original_file and verbose:
                LOGGER.info(
                    f"PRO TIP 💡 Replace 'model={original_file}' with new 'model={file}'.\nYOLOv5 'u' models are "
                    f"trained with https://github.com/ultralytics/ultralytics and feature improved performance vs "
                    f"standard YOLOv5 models trained with https://github.com/ultralytics/yolov5.\n"
                )
    return file
由于'yolov8.yaml'中既不包含"yolov3"也不包含"yolov5",该函数最外层的if语句为False,直接return file即可,file仍为'yolov8.yaml'。
接着回到check_file()函数中,执行如下语句:
if (
not file
or ("://" not in file and Path(file).exists()) # '://' check required in Windows Python<3.10
or file.lower().startswith("grpc://")
): # file exists or gRPC Triton images
return file
elif download and file.lower().startswith(("https://", "http://", "rtsp://", "rtmp://", "tcp://")): # download
url = file # warning: Pathlib turns :// -> :/
file = url2file(file) # '%2F' to '/', split https://url.com/file.txt?auth
if Path(file).exists():
LOGGER.info(f"Found {clean_url(url)} locally at {file}") # file already exists
else:
downloads.safe_download(url=url, file=file, unzip=False)
return file
由于file此时取值为'yolov8.yaml',因此not file、Path(file).exists()(当前工作目录下不存在该文件)以及file.lower().startswith("grpc://")均为False;elif语句中的file.lower().startswith(("https://", ...))也为False。最终,会直接运行如下语句:
else: # search
files = glob.glob(str(ROOT / "cfg" / "**" / file), recursive=True) # find file
if not files and hard:
raise FileNotFoundError(f"'{file}' does not exist")
elif len(files) > 1 and hard:
raise FileNotFoundError(f"Multiple files match '{file}', specify exact path: {files}")
return files[0] if len(files) else [] # return file
第一语句:
files = glob.glob(str(ROOT / "cfg" / "**" / file), recursive=True) # find file
'''
ROOT来自于./ultralytics-main/ultralytics/utils/__init__.py,如下:
FILE = Path(__file__).resolve()
ROOT = FILE.parents[1] # YOLO
'''
返回的是一个列表:即yolov8.yaml的绝对文件路径
['D:\\ultralytics-main\\ultralytics\\cfg\\models\\v8\\yolov8.yaml']
可知,列表中恰好有一个元素,两个raise分支都不会触发(况且此处传入的hard=False),因此直接执行return files[0] if len(files) else [],返回files[0]。
则直接返回到yaml_model_load()函数:
def yaml_model_load(path):
    """
    Load a YOLOv8 model from a YAML file.

    Args:
        path (str | Path): YAML file name or path, e.g. 'yolov8n.yaml'.

    Returns:
        (dict): The loaded YAML content plus two added keys: 'scale' (from `guess_model_scale(path)`)
            and 'yaml_file' (the path as a string).
    """
    import re

    path = Path(path)
    # Legacy P6 names (yolov5n6, yolov8x6, ...) are renamed to the '-p6' form with a warning.
    if path.stem in (f"yolov{d}{x}6" for x in "nsmlx" for d in (5, 8)):
        new_stem = re.sub(r"(\d+)([nslmx])6(.+)?$", r"\1\2-p6\3", path.stem)
        LOGGER.warning(f"WARNING ⚠️ Ultralytics YOLO P6 models now use -p6 suffix. Renaming {path.stem} to {new_stem}.")
        path = path.with_name(new_stem + path.suffix)

    # Drop the scale letter so every scale resolves to one shared config file.
    unified_path = re.sub(r"(\d+)([nslmx])(.+)?$", r"\1\3", str(path))  # i.e. yolov8x.yaml -> yolov8.yaml
    # The resolved config file path is received here; the original name is only tried
    # if the unified name is not found.
    yaml_file = check_yaml(unified_path, hard=False) or check_yaml(path)
    d = yaml_load(yaml_file)  # model dict
    d["scale"] = guess_model_scale(path)
    d["yaml_file"] = str(path)
    return d
其中,yaml_file用来接收yolov8.yaml的路径,d = yaml_load(yaml_file)则用来加载yolov8.yaml文件信息,d为字典形式;随后还会向d中写入"scale"(由guess_model_scale(path)得到)和"yaml_file"(原始路径字符串,即'yolov8n.yaml')两个键,最后返回d。
至此,则梳理明白了YOLOv8如何从model = YOLO('yolov8n.yaml')加载模型的原理。
如有错误,敬请指正。