YOLOv8逐行解析一:从.pt文件中加载模型

测试代码如下

from ultralytics import YOLO
# 创建一个yolo类的对象, 调用YOLO的构造函数
model = YOLO("weights/yolov8n.pt", task="detect")  # 加载模型

# 使用模型
model.train(data="ultralytics/cfg/datasets/coco8.yaml",
            epochs=10,
            batch=16,
            imgsz=640,
            workers=0,
            )  # 训练模型

开始逐行解析

ctrl点击YOLO转跳到yolo/model.py,YOLO类的声明处。

构造函数

    def __init__(self, model="yolov8n.pt", task=None, verbose=False):
        """Initialize YOLO model, switching to YOLOWorld if model filename contains '-world'."""
        # 获取文件路径,model只是一个字符串,path是一个类表示路径
        path = Path(model)
        # 模型名称是否以-world结尾, 比如yolo-world.pt yolo-world.yaml, 所以不会走这个分支
        if "-world" in path.stem and path.suffix in {".pt", ".yaml", ".yml"}:  # if YOLOWorld PyTorch model
            new_instance = YOLOWorld(path, verbose=verbose)
            self.__class__ = type(new_instance)
            self.__dict__ = new_instance.__dict__
        else:
            # Continue with default YOLO initialization
            # 调用父类构造函数, 次数model是"weights/yolov8n.pt" task是"detect"
            super().__init__(model=model, task=task, verbose=verbose)

在pycharm中点击super().__init__(model=model, task=task, verbose=verbose)进行转跳

YOLO类的父类构造函数

class Model(nn.Module):    
    def __init__(
        self,
        model: Union[str, Path] = "yolov8n.pt",
        task: str = None,
        verbose: bool = False,
    ) -> None:
        super().__init__() # 调用父类构造函数, 这里不做解释
        # 设置各种事件的回调函数
        self.callbacks = callbacks.get_default_callbacks()
        # 一些声明
        self.predictor = None  # 预测器
        self.model = None  # 模型, 这里是指一个对象而不是一个路径
        self.trainer = None  # 训练器
        self.ckpt = None  # 表示是否从pt文件中加载模型
        self.cfg = None  # 表示是否从yaml文件中加载模型
        self.ckpt_path = None # pt文件的路径
        self.overrides = {}  # 要覆盖的配置
        self.metrics = None  # 验证/训练过程中的度量指标
        
        self.session = None  # HUB会话,可能用于分布式训练或其它会话管理
        self.task = task  # task type # 设置task为"detect"
        model = str(model).strip() # 转换为字符串去掉首尾空格

        # Check if Ultralytics HUB model from https://hub.ultralytics.com
        # 检查是否是Ultralytics HUB中的模型,当模型路径是一个链接时可能是
        # 但现在是本地的模型文件, 所以不会进入
        if self.is_hub_model(model):
            # Fetch model from HUB
            checks.check_requirements("hub-sdk>=0.0.6")
            self.session = self._get_hub_session(model)
            model = self.session.model_file

        # Check if Triton Server model
        # 同上
        elif self.is_triton_model(model):
            self.model_name = self.model = model
            self.task = task
            return

        # Load or create new YOLO model
        # 检查模型后缀名, yaml文件会运行_new(), pt文件运行_load
        if Path(model).suffix in {".yaml", ".yml"}:
            self._new(model, task=task, verbose=verbose)
        else:
            self._load(model, task=task)

下期从.yaml中加载一个模型时从这里开始就会进入_new()函数,敬请期待

进入_load()函数

    def _load(self, weights: str, task=None) -> None:
        """
        Initializes a new model and infers the task type from the model head.

        Args:
            weights (str): model checkpoint to be loaded
            task (str | None): model task
        """
        # 当前weight为"weights/yolov8n.pt", task为"detect"
        # 检测是路径否为链接, 是则自动下载, 所以不会运行
        if weights.lower().startswith(("https://", "http://", "rtsp://", "rtmp://", "tcp://")):
            weights = checks.check_file(weights)  # automatically download and return local filename
        # 补全文件后缀名, 如官方注释中yolov8n -> yolov8n.pt
        weights = checks.check_model_file_from_stem(weights)  # add suffix, i.e. yolov8n -> yolov8n.pt
        
        # 这里显然判断成立
        if Path(weights).suffix == ".pt":
            # 尝试加载模型,
            self.model, self.ckpt = attempt_load_one_weight(weights)
            # 此时self.model是一个torch网络模型的抽象类,具体来说就是DetectionModel类
            # model.args被设置为默认参数, 见ultralytics/cfg/default.yaml,加载为一个字典
            # ckpt是一个字典,保存了model的一些信息
            self.task = self.model.args["task"] # "detect"字符串
            self.overrides = self.model.args = self._reset_ckpt_args(self.model.args)
            self.ckpt_path = self.model.pt_path # pt文件路径
        else:
            # 不会运行
            weights = checks.check_file(weights)  # runs in all cases, not redundant with above call
            self.model, self.ckpt = weights, None
            self.task = task or guess_model_task(weights)
            self.ckpt_path = weights

        # 设置覆盖的参数
        self.overrides["model"] = weights # 模型文件的路径
        self.overrides["task"] = self.task # 检测然我
        self.model_name = weights # 模型名字

至此完成一个pt文件的加载

细看attempt_load_one_weight()

def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False):
    """加载单个模型权重。"""
    
    # 从文件系统中加载pt文件,ckpt为一个字典,包含了一个描述网络模型的抽象类
    ckpt, weight = torch_safe_load(weight)  # 加载ckpt
    
    # 合并两个字典,当有相同的键时后者会覆盖前者
    args = {**DEFAULT_CFG_DICT, **(ckpt.get("train_args", {}))}  # 合并模型和默认参数,优先使用模型参数
    
    # ema的值为None时,model即为ckpt键"model"的值,并将模型加载到指定设备上
    model = (ckpt.get("ema") or ckpt["model"]).to(device).float()  # FP32模型
    
    # 更新模型中的参数
    model.args = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS}  # 将args附加到模型上
    
    # 设置模型在文件系统中的路径
    model.pt_path = weight  # 将*.pt文件路径附加到模型上
    
    # 推测模型的任务类型
    model.task = guess_model_task(model)
    
    # 如果模型中没有"stride"属性,则增加一个
    if not hasattr(model, "stride"):
        model.stride = torch.tensor([32.0])
    
    # 设置模型为评估模式,如果fuse为True且模型有fuse方法,则调用fuse方法
    model = model.fuse().eval() if fuse and hasattr(model, "fuse") else model.eval()  # 模型进入评估模式
    
    # 模块更新
    for m in model.modules():
        if hasattr(m, "inplace"):
            m.inplace = inplace
        elif isinstance(m, nn.Upsample) and not hasattr(m, "recompute_scale_factor"):
            m.recompute_scale_factor = None  # 兼容torch 1.11.0
    
    # 返回模型和ckpt
    return model, ckpt

  • 4
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 3
    评论
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值