pytorch训练后pt模型中保存内容详解(yolov8n.pt为例)

在 PyTorch 中,.pt 模型文件通常包含以下几类数据:

        模型参数:

                存储模型的权重和偏置参数。

        优化器状态:

                包含优化器的状态信息,以便在恢复训练时能够从中断的地方继续。

        训练状态:

                一些训练过程中的信息,例如当前的 epoch 数和训练进度。

        其他元数据:

                包括模型的配置、训练时使用的超参数等。

        在讲解pytorch pt(pth)文件中保存了什么内容之前,需要先了解pt在保存时保存了那些参数。

以YOLO系列pt保存代码来介绍说明:

1. 模型保存代码:

 def save_model(self):
        ckpt = {
            'epoch': self.epoch, #
            'best_fitness': self.best_fitness,
            'model': deepcopy(de_parallel(self.model)).half(),
            'ema': deepcopy(self.ema.ema).half(),
            'updates': self.ema.updates,
            'optimizer': self.optimizer.state_dict(),
            'train_args': vars(self.args),  # save as dict
            'date': datetime.now().isoformat(),
            'version': __version__}
        # Use dill (if exists) to serialize the lambda functions where pickle does not do this
        try:
            import dill as pickle
        except ImportError:
            import pickle
        # Save last, best and delete
        torch.save(ckpt, self.last, pickle_module=pickle)
        if self.best_fitness == self.fitness:
            torch.save(ckpt, self.best, pickle_module=pickle)
        if (self.epoch > 0) and (self.save_period > 0) and (self.epoch % self.save_period == 0):
            torch.save(ckpt, self.wdir / f'epoch{self.epoch}.pt', pickle_module=pickle)
        del ckpt

参数说明:

        'epoch': 当前的训练轮次数。

        'best_fitness': 最佳性能指标的数值。

        'model': 深拷贝(deepcopy)并将模型参数进行半精度(half)转换后的模型。

        'ema': 深拷贝并将指数移动平均模型参数进行半精度转换后的指数移动平均模型。

        'updates': 指数移动平均模型的更新次数。

        'optimizer': 优化器的状态字典(state_dict)。

        'train_args': 训练参数的字典表示,使用vars(self.args)将self.args对象转换为字典。

        'date': 当前的日期和时间,使用datetime.now().isoformat()获取。

        'version': 代码的版本号,通过__version__获取。

        其中:model中保存的模型的结构,train_args中保存训练时的一些参数(超参数)。

通过上述功能函数可以看到pytorch保存的pt文件中的内容。

补充说明:

        torch.save()函数用于将PyTorch模型保存到磁盘上的文件中,以便以后可以重新加载和使用。它的基本语法如下:

        torch.save(obj, f, pickle_module=<module 'pickle' from '...'>, pickle_protocol=2)

                obj是要保存的对象,通常是一个模型的状态字典(state_dict())。

                f是文件的路径或文件对象,用于存储模型。

                pickle_module是用于序列化的Python模块,默认为pickle。

                pickle_protocol是序列化时使用的协议版本,默认为2。

2. 模型加载介绍

下面通过Debug来详解pt中的具体内容:

首先加载模型,代码如下:

import sys
import argparse
import os
import struct
import torch
pt_file = "./yolov8n.pt"
wts_file = "./yolov8n.wts"
# Initialize
device = 'cpu'
# Load model
modelAll = torch.load(pt_file, map_location=device)
model = modelAll['model'].float()  # load to FP32
#model = torch.load(pt_file, map_location=device)['model'].float()  # load to FP32

anchor_grid = model.model[-1].anchors * model.model[-1].stride[..., None, None]
delattr(model.model[-1], 'anchors')
model.to(device).eval()
with open(wts_file, 'w') as f:
    f.write('{}\n'.format(len(model.state_dict().keys())))
    for k, v in model.state_dict().items():
        print("key={0}, v={1}".format(k,v))
        vr = v.reshape(-1).cpu().numpy()
        f.write('{} {} '.format(k, len(vr)))
        for vv in vr:
            f.write(' ')
            f.write(struct.pack('>f', float(vv)).hex())
        f.write('\n')

 Debug结果如下所示,分别对应save_model()中保存的内容

其中model(model = modelAll['model'].float())中内容如下:

       model的类型为DetectionModel,里面包含了模型结构(model.model)以及参数信息(model.args)及构造网络时的配置参数信息(model.yaml)以及目标类别及个数、stride等信息。 

3. 模型权重解析保存

        model.state_dict()是一个字典,键是参数的名称,值是对应的 tensor。

        其中保存着模型的权重(Weights)和偏置值(Biases)以及运行均值和方差(例如,Batch Normalization 层的 running_mean 和 running_var,用于推理时)等信息。

        权重解析保存代码如下:

with open(wts_file, 'w') as f:
    f.write('{}\n'.format(len(model.state_dict().keys())))
    for k, v in model.state_dict().items():
        print("key={0}, v={1}".format(k,v))
        vr = v.reshape(-1).cpu().numpy()
        f.write('{} {} '.format(k, len(vr)))
        for vv in vr:
            f.write(' ')
            f.write(struct.pack('>f', float(vv)).hex())
        f.write('\n')

代码功能介绍:

  1. 使用写模式打开一个文件 wts_file,以便保存模型的参数。
  2. 将模型参数的数量写入文件。
  3. 循环遍历每个参数的键名 k 和对应的值 v。
  4. 将参数 v 重塑为一维数组,并将其从 GPU 移动到 CPU(如果适用),然后转换为 NumPy 数组。
  5. 写入参数的名称和长度。
    for vv in vr:
        f.write(' ')
        f.write(struct.pack('>f', float(vv)).hex())

        遍历每个参数值,使用大端格式(‘>’)将其转换为浮点数并写入文件.

pt解包后保存后的文件内容如下:

上述代码可以将pt格式模型,转化为Nvidia TensorRT部署需要的文件。 

YoloV7是目标检测算法YOLO的最新版本,相较于之前的版本,它在模型结构、训练策略和速度等方面都有了较大的改进。test.py文件是用于测试已经训练好的模型的脚本,下面是对test.py文件的详细解释: 1. 导入必要的库和模块 ```python import argparse import os import platform import shutil import time from pathlib import Path import cv2 import torch import torch.backends.cudnn as cudnn import numpy as np from models.experimental import attempt_load from utils.datasets import LoadStreams, LoadImages from utils.general import check_img_size, check_requirements, check_imshow, \ non_max_suppression, apply_classifier, scale_coords, xyxy2xywh, strip_optimizer, set_logging from utils.plots import plot_one_box from utils.torch_utils import select_device, load_classifier, time_synchronized ``` 这里导入了一些必要的库和模块,比如PyTorch、OpenCV、NumPy等,以及用于测试的模型、数据集和一些工具函数。 2. 定义输入参数 ```python parser = argparse.ArgumentParser() parser.add_argument('--weights', nargs='+', type=str, default='yolov5s.pt', help='model.pt path(s)') parser.add_argument('--source', type=str, default='data/images', help='source') parser.add_argument('--img-size', type=int, default=640, help='inference size (pixels)') parser.add_argument('--conf-thres', type=float, default=0.25, help='object confidence threshold') parser.add_argument('--iou-thres', type=float, default=0.45, help='IOU threshold for NMS') parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu') parser.add_argument('--view-img', action='store_true', help='display results') parser.add_argument('--save-txt', action='store_true', help='save results to *.txt') parser.add_argument('--save-conf', action='store_true', help='save confidences in --save-txt labels') parser.add_argument('--save-crop', action='store_true', help='save cropped prediction boxes') parser.add_argument('--nosave', action='store_true', help='do not save images/videos') parser.add_argument('--classes', nargs='+', type=int, help='filter by class: --class 0, or --class 0 2 3') parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS') parser.add_argument('--augment', action='store_true', help='augmented inference') parser.add_argument('--update', action='store_true', help='update all models') parser.add_argument('--project', default='runs/detect', help='save results to project/name') parser.add_argument('--name', default='exp', help='save results to project/name') parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment') opt = parser.parse_args() ``` 这里使用Python的argparse库来定义输入参数,包括模型权重文件、输入数据源、推理尺寸、置信度阈值、NMS阈值等。 3. 加载模型 ```python # 加载模型 model = attempt_load(opt.weights, map_location=device) # load FP32 model imgsz = check_img_size(opt.img_size, s=model.stride.max()) # check img_size if device.type != 'cpu': model(torch.zeros(1, 3, imgsz, imgsz).to(device).type_as(next(model.parameters()))) # run once ``` 这里使用`attempt_load()`函数来加载模型,该函数会根据传入的权重文件路径自动选择使用哪个版本的YoloV7模型。同时,这里还会检查输入图片的大小是否符合模型的要求。 4. 设置计算设备 ```python # 设置计算设备 device = select_device(opt.device) half = device.type != 'cpu' # half precision only supported on CUDA # Initialize model model.to(device).eval() ``` 这里使用`select_device()`函数来选择计算设备(GPU或CPU),并将模型移动到选择的设备上。 5. 加载数据集 ```python # 加载数据集 if os.path.isdir(opt.source): dataset = LoadImages(opt.source, img_size=imgsz) else: dataset = LoadStreams(opt.source, img_size=imgsz) ``` 根据输入参数的数据源,使用`LoadImages()`或`LoadStreams()`函数来加载数据集。这两个函数分别支持从图片文件夹或摄像头/视频读取数据。 6. 定义类别和颜色 ```python # 定义类别和颜色 names = model.module.names if hasattr(model, 'module') else model.names colors = [[np.random.randint(0, 255) for _ in range(3)] for _ in names] ``` 这里从模型获取类别名称,同时为每个类别随机生成一个颜色,用于在图片绘制框和标签。 7. 定义输出文件夹 ```python # 定义输出文件夹 save_dir = Path(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok)) # increment run (save_dir / 'labels' if opt.save_txt else save_dir).mkdir(parents=True, exist_ok=True) # make dir ``` 这里使用`increment_path()`函数来生成输出文件夹的名称,同时创建相应的文件夹。 8. 开始推理 ```python # 开始推理 for path, img, im0s, vid_cap in dataset: t1 = time_synchronized() # 图像预处理 img = torch.from_numpy(img).to(device) img = img.half() if half else img.float() img /= 255.0 if img.ndimension() == 3: img = img.unsqueeze(0) # 推理 pred = model(img)[0] # 后处理 pred = non_max_suppression(pred, opt.conf_thres, opt.iou_thres, classes=opt.classes, agnostic=opt.agnostic_nms) t2 = time_synchronized() # 处理结果 for i, det in enumerate(pred): # detections per image if webcam: # batch_size >= 1 p, s, im0 = path[i], f'{i}: ', im0s[i].copy() else: p, s, im0 = path, '', im0s save_path = str(save_dir / p.name) txt_path = str(save_dir / 'labels' / p.stem) + ('' if dataset.mode == 'image' else f'_{counter}') + '.txt' if det is not None and len(det): det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round() for *xyxy, conf, cls in reversed(det): c = int(cls) label = f'{names[c]} {conf:.2f}' plot_one_box(xyxy, im0, label=label, color=colors[c], line_thickness=3) if opt.save_conf: with open(txt_path, 'a') as f: f.write(f'{names[c]} {conf:.2f}\n') if opt.save_crop: w = int(xyxy[2] - xyxy[0]) h = int(xyxy[3] - xyxy[1]) x1 = int(xyxy[0]) y1 = int(xyxy[1]) x2 = int(xyxy[2]) y2 = int(xyxy[3]) crop_img = im0[y1:y2, x1:x2] crop_path = save_path + f'_{i}_{c}.jpg' cv2.imwrite(crop_path, crop_img) # 保存结果 if opt.nosave: pass elif dataset.mode == 'images': cv2.imwrite(save_path, im0) else: if vid_path != save_path: # new video vid_path = save_path if isinstance(vid_writer, cv2.VideoWriter): vid_writer.release() # release previous video writer fourcc = 'mp4v' # output video codec fps = vid_cap.get(cv2.CAP_PROP_FPS) w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH)) h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) vid_writer = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*fourcc), fps, (w, h)) vid_writer.write(im0) # 打印结果 print(f'{s}Done. ({t2 - t1:.3f}s)') # 释放资源 if cv2.waitKey(1) == ord('q'): # q to quit raise StopIteration elif cv2.waitKey(1) == ord('p'): # p to pause cv2.waitKey(-1) ``` 这里使用一个循环来遍历数据集的所有图像或视频帧,对每张图像或视频帧进行以下操作: - 图像预处理:将图像转换为PyTorch张量,并进行归一化和类型转换。 - 推理:将图像张量传入模型进行推理,得到预测结果。 - 后处理:对预测结果进行非极大值抑制、类别筛选等后处理操作,得到最终的检测结果。 - 处理结果:对每个检测框进行标签和颜色的绘制,同时可以选择保存检测结果的图片或视频以及标签信息的TXT文件。 - 释放资源:根据按键输入决定是否退出或暂停程序。 9. 总结 以上就是YoloV7的测试脚本test.py的详细解释,通过这个脚本可以方便地测试已经训练好的模型,并对检测结果进行可视化和保存等操作。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值