yolov3代码详解(五)

Pytorch | yolov3代码详解五

detect.py

from __future__ import division

from models import *
from utils.utils import *
from utils.datasets import *

import os
import sys
import time
import datetime
import argparse

from PIL import Image

import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.ticker import NullLocator
##########################################################################
#检测
##########################################################################

#detect.py主要的工作过程:
#1.解析命令行输入的各种参数,如果没有就使用默认的参数
#2.打印出当前的各种参数
#3.创建model
#4.加载model的权重
#5.加载测试图像
#6.加载data/coco.names中的类别名称
#7.算出batch中所有图片的地址img_paths和检测结果detections
#8.为detections里每个类别的物体选择一种颜色,把检测到的bboxes画到图上



"""
(1)import argparse    首先导入模块
(2)parser = argparse.ArgumentParser()    创建一个解析对象
(3)parser.add_argument()    向该对象中添加你要关注的命令行参数和选项
(4)parser.parse_args()    进行解析
"""


if __name__ == "__main__":

    #1.解析命令行输入的各种参数,如果没有就使用默认的参数
    parser = argparse.ArgumentParser()
    parser.add_argument("--image_folder", type=str, default="data/samples", help="path to dataset")
    parser.add_argument("--model_def", type=str, default="config/yolov3.cfg", help="path to model definition file")
    parser.add_argument("--weights_path", type=str, default="weights/yolov3.weights", help="path to weights file")
    parser.add_argument("--class_path", type=str, default="data/coco.names", help="path to class label file")
    parser.add_argument("--conf_thres", type=float, default=0.8, help="object confidence threshold")  #目标置信阈值
    parser.add_argument("--nms_thres", type=float, default=0.4, help="iou thresshold for non-maximum suppression")
    parser.add_argument("--batch_size", type=int, default=1, help="size of the batches")
    parser.add_argument("--n_cpu", type=int, default=0, help="number of cpu threads to use during batch generation")
    parser.add_argument("--img_size", type=int, default=416, help="size of each image dimension")    #图片大小
    parser.add_argument("--checkpoint_model", type=str, help="path to checkpoint model")            #检测模型路径
    opt = parser.parse_args()
     #2.打印出当前的各种参数
    print(opt)
    #选择是否使用GPU设备
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    os.makedirs("output", exist_ok=True)

    #3.创建模型
    # Set up model  这条语句加载darkent模型,即YOLOv3模型。Darknet模型在model.py中进行定义
    model = Darknet(opt.model_def, img_size=opt.img_size).to(device)
    
    #4.加载模型的权重
    #Darknet(YOLOv3)模型基本加载完毕,接下来就是,加载权重.weights文件,进行预测。
    #查找weights_path路径下的.weights的文件
    if opt.weights_path.endswith(".weights"):
        # Load darknet weights
        model.load_darknet_weights(opt.weights_path)
    else:
        # Load checkpoint weights
        model.load_state_dict(torch.load(opt.weights_path))

    # model.eval(),让model变成测试模式,这主要是对dropout和batch normalization的
    # 操作在训练和测试的时候是不一样的
    """
    model.train() :启用 BatchNormalization 和 Dropout
    model.eval() :不启用 BatchNormalization 和 Dropout
    """
    model.eval()  # Set in evaluation mode



    #5.加载测试图像
    """
    加载数据
    DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, 
    num_workers=0, collate_fn=default_collate, pin_memory=False, 
    drop_last=False)

    dataset:加载的数据集(Dataset对象)
    batch_size:batch size
    shuffle::是否将数据打乱
    sampler: 样本抽样,后续会详细介绍
    num_workers:使用多进程加载的进程数,0代表不使用多进程
    collate_fn: 如何将多个样本数据拼接成一个batch,一般使用默认的拼接方式即可
    pin_memory:是否将数据保存在pin memory区,pin memory中的数据转到GPU会快一些
    drop_last:dataset中的数据个数可能不是batch_size的整数倍,drop_last为True会将多出来不足一个batch的数据丢弃
    """

    """
    ImageFolder(root,transform=None,target_transform=None,loader=
    default_loader)
    root : 在指定的root路径下面寻找图片
    transform: 对PIL Image进行转换操作,transform 输入是loader读取图片返回的对象
    target_transform :对label进行变换
    loader: 指定加载图片的函数,默认操作是读取PIL image对象
    —————————————————————————————————————————————————————————————————————————————
    文中的函数 由重新写了,在datasets中
    ImageFolder是遍历文件夹下的测试图片,完整定义如下。
    ImageFolder中的__getitem__()函数会把图像归一化处理成img_size(默认416)大小的图片。
    """
    dataloader = DataLoader(
        ImageFolder(opt.image_folder, img_size=opt.img_size),
        batch_size=opt.batch_size,
        shuffle=False,
        num_workers=opt.n_cpu,
    )

    #6.加载data/coco.names中的类别名称
    #类别信息
    classes = load_classes(opt.class_path)  # Extracts class labels from file

    Tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor

    imgs = []  # Stores image paths  存储图像路径
    img_detections = []  # Stores detections for each image index  存储每个图像索引的检测

    print("\nPerforming object detection:")
    prev_time = time.time()  #返回当前时间的时间戳
     #7.算出batch中所有图片的地址img_paths和检测结果detections
    for batch_i, (img_paths, input_imgs) in enumerate(dataloader):
        # Configure input  输入, Variable变量
        input_imgs = Variable(input_imgs.type(Tensor))

        # Get detections
        with torch.no_grad():
            #把图像放进模型中,得到检测结果。这里是通过Darknet的forward()函数得到检测结果。
            detections = model(input_imgs)     #model是上面model = Darknet(opt.model_def, img_size=opt.img_size).to(device)得到的
            #在获取检测框之后,需要使用非极大值抑制来筛选框。(在utils中)
            detections = non_max_suppression(detections, opt.conf_thres, opt.nms_thres)

        # Log progress  记录进度
        current_time = time.time()   #返回当前时间的时间戳
        #timedelta代表两个datetime之间的时间差
        inference_time = datetime.timedelta(seconds=current_time - prev_time)
        prev_time = current_time
        print("\t+ Batch %d, Inference Time: %s" % (batch_i, inference_time))

        # Save image and detections
        #extend() 函数用于在列表末尾一次性追加另一个序列中的多个值(用新列表扩展原来的列表)
        imgs.extend(img_paths)
        img_detections.extend(detections)

    # Bounding-box colors
    cmap = plt.get_cmap("tab20b")
    colors = [cmap(i) for i in np.linspace(0, 1, 20)]

    print("\nSaving images:")
    # Iterate through images and save plot of detections
    for img_i, (path, detections) in enumerate(zip(imgs, img_detections)):

        print("(%d) Image: '%s'" % (img_i, path))

        # Create plot
        img = np.array(Image.open(path))   #PIL.Image.open()专接图片路径,用来直接读取该路径指向的图片。要求路径必须指明到哪张图,不能只是所有图所在的文件夹;
        #img为原始图片
        plt.figure()
        fig, ax = plt.subplots(1)
        ax.imshow(img)

        # Draw bounding boxes and labels of detections
        if detections is not None:
            # Rescale boxes to original image  将框重缩放到原始图像
            detections = rescale_boxes(detections, opt.img_size, img.shape[:2])
            unique_labels = detections[:, -1].cpu().unique()   #得到类别
            n_cls_preds = len(unique_labels)
            bbox_colors = random.sample(colors, n_cls_preds)   #根据类别不同,得到不同框的颜色
            for x1, y1, x2, y2, conf, cls_conf, cls_pred in detections:

                print("\t+ Label: %s, Conf: %.5f" % (classes[int(cls_pred)], cls_conf.item()))

                box_w = x2 - x1
                box_h = y2 - y1

                color = bbox_colors[int(np.where(unique_labels == int(cls_pred))[0])]
                # Create a Rectangle patch  创建矩形
                bbox = patches.Rectangle((x1, y1), box_w, box_h, linewidth=2, edgecolor=color, facecolor="none")
                # Add the bbox to the plot  添加矩形到图片
                ax.add_patch(bbox)
                # Add label
                plt.text(
                    x1,
                    y1,
                    s=classes[int(cls_pred)],
                    color="white",
                    verticalalignment="top",
                    bbox={"color": color, "pad": 0},
                )

        # Save generated image with detections
        plt.axis("off")
        plt.gca().xaxis.set_major_locator(NullLocator())
        plt.gca().yaxis.set_major_locator(NullLocator())
        filename = path.split("/")[-1].split(".")[0]
        plt.savefig(f"output/{filename}.png", bbox_inches="tight", pad_inches=0.0)
        plt.close()

YOLOv3是一种目标检测算法,它在PyTorch框架下实现。你可以在GitHub上找到YOLOv3的PyTorch版本代码,地址是https://github.com/ultralytics/yolov3。这个代码库提供了一些教程和运行结果,但不一定能直接运行成功。你可以在同目录下新建一个.ipynb文件,并在其中运行代码"%run detect.py"来尝试运行。\[1\] 在代码解读方面,首先需要准备数据集和关键文件。然后,代码的大致流程包括数据与标签的读取、模型构造、前向传播和计算损失。具体来说,模型构造部分包括构建convolutional层、rout层和shortcut层,以及构建yolo层。\[2\] 如果你想深入了解YOLOv3的PyTorch版本代码,可以参考官方教程,地址是https://github.com/ultralytics/yolov3/wiki/Train-Custom-Data。这个教程提供了更详细的训练自定义数据集的指导。\[3\] #### 引用[.reference_title] - *1* *3* [YOLOv3 Pytorch代码及原理分析(一):跑通代码](https://blog.csdn.net/weixin_43605641/article/details/107524168)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^insertT0,239^v3^insert_chatgpt"}} ] [.reference_item] - *2* [YOLOV3 Pytorch版本代码解读](https://blog.csdn.net/Weary_PJ/article/details/128749270)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^insertT0,239^v3^insert_chatgpt"}} ] [.reference_item] [ .reference_list ]
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值