Pytorch AI开发全流程常用代码工具总结

加载Model

:
models.common

DetectMultiBackend(自动在ONNX,PT等Model的权重文件格式中识别具体的网络存储类型并产生一个pytorch网络模型)
例子
model = DetectMultiBackend(weights, device=device, dnn=dnn, data=data, fp16=half)
weights=ROOT / “yolov5s.pt”, # model path or triton URL
dnn=False, # use OpenCV DNN for ONNX inference
data=ROOT / “data/coco128.yaml”, # dataset.yaml path

构建推理用数据集(推理数据(流)来源)

:
utils.dataloaders

LoadStreams() 主要用于从视频流中加载数据,如视频文件或网络摄像头输入,适合处理连续的帧数据。
LoadImages() 主要用于加载单个图像或从一个目录中批量加载静态图像文件
例子
dataset = LoadImages(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride) #从PC中的文件资源获取数据
source = str(source) #source对应一个PC中文件资源的地址
imgsz = check_img_size(imgsz, s=stride)
imgsz=(640, 640), # inference size (height, width)
stride, names, pt = model.stride, model.names, model.pt #通过DetectMultiBackend读取模型后,可以直接从得到的模型类中获取模型总的参数信息
dataset = LoadStreams(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride) #从webcam中获取推理数据流
vid_stride:- 含义:视频加载时的帧间隔。
auto (pt):- 含义:自动调整图像尺寸的标志

部署模型到GPU

:
models
utils.torch_utils
:
models.warmup()(预热模型,(提前在GPU中运行一些散乱数据,提高模型首次运行的表现与稳定性))
select_device() (根据用户指定的设备(通常是设备ID或名称),自动选择合适的计算设备,无输入时默认优先GPU)
例子 :
model.warmup(imgsz=(1 if pt or model.triton else bs, 3, *imgsz))
device = select_device(device)
device = “”

(待推理数据(集/流)进入到GPU)

:
torch
:
from() (将 NumPy 数组 im 转换为 PyTorch 张量)
to() (将生成的 PyTorch 张量移动到 model.device 所指定的设备上)
例子 :
im = torch.from_numpy(im).to(model.device)
for path, im, im0s, vid_cap, s in dataset: #从dataset中获取im这一numpy张量格式图像

推理用数据集的使用

:
torch
:
from() (将 NumPy 数组 im 转换为 PyTorch 张量)
to() (将生成的 PyTorch 张量移动到 model.device 所指定的设备上)
model() (获取来自dataset转移到device上的im图像数据,进行推理)
getattr() (是 Python 的一个内置函数,用于从对象中获取指定属性的值。如果该属性不存在,还可以提供一个默认值来避免报错)
例子 :
for path, im, im0s, vid_cap, s in dataset:
dataset = LoadStreams(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride) #dataset为一个LoadStream生成的迭代器,for每次迭代后,迭代器会返回一批元素,包括:
- path
- 含义:当前帧对应的视频流的路径或名称。对于视频文件,可能是文件路径;对于网络摄像头,可能是摄像头的标识符(如 "0" 表示第一个摄像头)。
- im
- 含义:当前视频帧经过预处理后的图像数据,通常以张量的形式返回,以便直接输入到深度学习模型中。
- im0s
- 含义:原始视频帧的图像数据(未经过预处理)。这通常是一个 numpy 数组,表示从视频流中提取的帧。
- vid_cap
- 含义:视频捕获对象,通常是一个 cv2.VideoCapture 对象,用于读取视频流的属性或从视频中抓取帧。这在使用 OpenCV 读取视频时很常见。
- s
- 含义:附加的字符串信息,通常用于存储或打印关于当前帧的一些信息。这可能包括帧的尺寸、帧率等。
im = torch.from_numpy(im).to(model.device)
for path, im, im0s, vid_cap, s in dataset: #从dataset中获取im这一numpy张量格式图像,并输入到模型对应的设备(一般为GPU)中
pred = model(im, augment=augment, visualize=visualize)
im = torch.from_numpy(im).to(model.device) #im与model需要再同一个设备上才可以将im输送给具体的网络模型并执行网络推理
p, im0, frame = path, im0s.copy(), getattr(dataset, “frame”, 0)

Yolov5中的数据处理

:
utils.general
:
LOGGER,
Profile,
check_file,
check_img_size,
check_imshow,
check_requirements,
colorstr,
cv2,
increment_path,
non_max_suppression,
print_args,
scale_boxes,
strip_optimizer,
xyxy2xywh,
例子 :
pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)
pred = model(im, augment=augment, visualize=visualize)
conf_thres=0.25, # confidence threshold 置信度低于该阈值的检测框都会被过滤掉
iou_thres=0.45, # NMS IOU threshold NMS 会移除 IoU 高于该阈值的检测框,以减少重叠的检测框,保留最具代表性的一个
classes=None, # filter by class: --class 0, or --class 0 2 3 指定只对哪些类别的预测结果应用 NMS
agnostic_nms=False, # class-agnostic NMS
max_det=1000, # maximum detections per image 表示每个图像中最多保留的检测框数量

环境配置

:
utils.general
argparse
:
- general.check_requirements (检查当前环境是否满足某个’‘requirements.txt’’ 文件中列出的依赖项,同时排除指定的依赖项)
- general.LOGGER (一个日志记录器对象,通常通过 Python 的 logging 模块创建并配置。它用于记录和管理程序的日志信息。)
- general.print_args (用于输出argparse实例的内部参数数据(配合opt=parser.parse_args()方法来使用))
- argparse.ArgumentParser (创建一个对象,通过调用其add_argument()方法来加入新的参数,用于解析命令行参数
例子 :
check_requirements(ROOT / “requirements.txt”, exclude=(“tensorboard”, “thop”))
parser = argparse.ArgumentParser()
parser.add_argument(“–data”, type=str, default=ROOT / “data/coco128.yaml”, help=“dataset.yaml path”)
接着跟上代码:
opt = parser.parse_args() (将解析后参数传递给opt,opt为一个命名空间对象,它包含了所有解析后的命令行参数的值。你可以通过 opt.<parameter_name> 的方式访问各个参数的值。)
vars(opt)将opt这个命名空间对象解包为字典(vars为python内置函数)
run(**vars(opt)) 将字典中的键值对解包为函数的关键字参数
LOGGER.info(f"WARNING ⚠️ confidence threshold {opt.conf_thres} > 0.001 produces invalid results")
.info(…): info 方法表示记录一条信息级别的日志。在日志记录级别中,info 通常用于记录一般性的信息,属于较低的严重性级别。其他常见的日志级别包括 debugwarningerrorcritical

  • 3
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值