测试代码如下
from ultralytics import YOLO
# 创建一个yolo类的对象, 调用YOLO的构造函数
model = YOLO("weights/yolov8n.pt", task="detect") # 加载模型
# 使用模型
model.train(data="ultralytics/cfg/datasets/coco8.yaml",
epochs=10,
batch=16,
imgsz=640,
workers=0,
) # 训练模型
开始逐行解析
ctrl点击YOLO转跳到yolo/model.py,YOLO类的声明处。
构造函数
def __init__(self, model="yolov8n.pt", task=None, verbose=False):
"""Initialize YOLO model, switching to YOLOWorld if model filename contains '-world'."""
# 获取文件路径,model只是一个字符串,path是一个类表示路径
path = Path(model)
# 模型名称是否以-world结尾, 比如yolo-world.pt yolo-world.yaml, 所以不会走这个分支
if "-world" in path.stem and path.suffix in {".pt", ".yaml", ".yml"}: # if YOLOWorld PyTorch model
new_instance = YOLOWorld(path, verbose=verbose)
self.__class__ = type(new_instance)
self.__dict__ = new_instance.__dict__
else:
# Continue with default YOLO initialization
# 调用父类构造函数, 次数model是"weights/yolov8n.pt" task是"detect"
super().__init__(model=model, task=task, verbose=verbose)
在pycharm中点击super().__init__(model=model, task=task, verbose=verbose)进行转跳
YOLO类的父类构造函数
class Model(nn.Module):
def __init__(
self,
model: Union[str, Path] = "yolov8n.pt",
task: str = None,
verbose: bool = False,
) -> None:
super().__init__() # 调用父类构造函数, 这里不做解释
# 设置各种事件的回调函数
self.callbacks = callbacks.get_default_callbacks()
# 一些声明
self.predictor = None # 预测器
self.model = None # 模型, 这里是指一个对象而不是一个路径
self.trainer = None # 训练器
self.ckpt = None # 表示是否从pt文件中加载模型
self.cfg = None # 表示是否从yaml文件中加载模型
self.ckpt_path = None # pt文件的路径
self.overrides = {} # 要覆盖的配置
self.metrics = None # 验证/训练过程中的度量指标
self.session = None # HUB会话,可能用于分布式训练或其它会话管理
self.task = task # task type # 设置task为"detect"
model = str(model).strip() # 转换为字符串去掉首尾空格
# Check if Ultralytics HUB model from https://hub.ultralytics.com
# 检查是否是Ultralytics HUB中的模型,当模型路径是一个链接时可能是
# 但现在是本地的模型文件, 所以不会进入
if self.is_hub_model(model):
# Fetch model from HUB
checks.check_requirements("hub-sdk>=0.0.6")
self.session = self._get_hub_session(model)
model = self.session.model_file
# Check if Triton Server model
# 同上
elif self.is_triton_model(model):
self.model_name = self.model = model
self.task = task
return
# Load or create new YOLO model
# 检查模型后缀名, yaml文件会运行_new(), pt文件运行_load
if Path(model).suffix in {".yaml", ".yml"}:
self._new(model, task=task, verbose=verbose)
else:
self._load(model, task=task)
下期从.yaml中加载一个模型时从这里开始就会进入_new()函数,敬请期待
进入_load()函数
def _load(self, weights: str, task=None) -> None:
"""
Initializes a new model and infers the task type from the model head.
Args:
weights (str): model checkpoint to be loaded
task (str | None): model task
"""
# 当前weight为"weights/yolov8n.pt", task为"detect"
# 检测是路径否为链接, 是则自动下载, 所以不会运行
if weights.lower().startswith(("https://", "http://", "rtsp://", "rtmp://", "tcp://")):
weights = checks.check_file(weights) # automatically download and return local filename
# 补全文件后缀名, 如官方注释中yolov8n -> yolov8n.pt
weights = checks.check_model_file_from_stem(weights) # add suffix, i.e. yolov8n -> yolov8n.pt
# 这里显然判断成立
if Path(weights).suffix == ".pt":
# 尝试加载模型,
self.model, self.ckpt = attempt_load_one_weight(weights)
# 此时self.model是一个torch网络模型的抽象类,具体来说就是DetectionModel类
# model.args被设置为默认参数, 见ultralytics/cfg/default.yaml,加载为一个字典
# ckpt是一个字典,保存了model的一些信息
self.task = self.model.args["task"] # "detect"字符串
self.overrides = self.model.args = self._reset_ckpt_args(self.model.args)
self.ckpt_path = self.model.pt_path # pt文件路径
else:
# 不会运行
weights = checks.check_file(weights) # runs in all cases, not redundant with above call
self.model, self.ckpt = weights, None
self.task = task or guess_model_task(weights)
self.ckpt_path = weights
# 设置覆盖的参数
self.overrides["model"] = weights # 模型文件的路径
self.overrides["task"] = self.task # 检测然我
self.model_name = weights # 模型名字
至此完成一个pt文件的加载
细看attempt_load_one_weight()
def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False):
"""加载单个模型权重。"""
# 从文件系统中加载pt文件,ckpt为一个字典,包含了一个描述网络模型的抽象类
ckpt, weight = torch_safe_load(weight) # 加载ckpt
# 合并两个字典,当有相同的键时后者会覆盖前者
args = {**DEFAULT_CFG_DICT, **(ckpt.get("train_args", {}))} # 合并模型和默认参数,优先使用模型参数
# ema的值为None时,model即为ckpt键"model"的值,并将模型加载到指定设备上
model = (ckpt.get("ema") or ckpt["model"]).to(device).float() # FP32模型
# 更新模型中的参数
model.args = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS} # 将args附加到模型上
# 设置模型在文件系统中的路径
model.pt_path = weight # 将*.pt文件路径附加到模型上
# 推测模型的任务类型
model.task = guess_model_task(model)
# 如果模型中没有"stride"属性,则增加一个
if not hasattr(model, "stride"):
model.stride = torch.tensor([32.0])
# 设置模型为评估模式,如果fuse为True且模型有fuse方法,则调用fuse方法
model = model.fuse().eval() if fuse and hasattr(model, "fuse") else model.eval() # 模型进入评估模式
# 模块更新
for m in model.modules():
if hasattr(m, "inplace"):
m.inplace = inplace
elif isinstance(m, nn.Upsample) and not hasattr(m, "recompute_scale_factor"):
m.recompute_scale_factor = None # 兼容torch 1.11.0
# 返回模型和ckpt
return model, ckpt