YOLOV5dataset.py代码注释与解析

YOLOv5代码注释版更新啦,注释的是最近的2021.07.14的版本,且注释更全
github: https://github.com/Laughing-q/yolov5_annotations

YOLOV5dataset.py代码注释与解析

本文主要对ultralytics\yolov5在训练时的数据加载模块的dataset.py代码进行注释和解析。当然dataset.py中还有其他时候(例如detect时)所用到的加载方法(例如LoadImages、LoadWebcam等),本文主要是对训练时用到的LoadImagesAndLabels类的相关注释。

yolov5其他代码解析

2020.08.19
增加对检测时用到的LoadImages、LoadStreams的注释,代码中没使用到LoadWebcam函数,就没有写注释
现在的yolov5中dateset.py中LoadImagesAndLabels函数的调用有一些数据增强方面的代码(random_affine——>random_perspective)更新,但不影响整体思路,就不重新加以注释了

mosaic增强

在这里要说一下,mosaic数据增强就是将四张图片拼接在一起传入网络训练,具体可以查看YOLOV4-mosaic数据增强详解。(该文章是基于pytorch YOLOV4代码做的解析)

矩形训练

在这里插入图片描述

正方形填充

可以看到yolov5会对图片进行填充,填充为正方形从而传入网络进行训练,可以看到这里面有很多冗余的信息,会让网络产生很多无意义的候选框,矩形训练就是减少这些冗余信息,减少网络产生的无意义的框的数量,加快网络训练速度。yolov5网络的总步长为32,所以其实只要图片边长能够整除32就可以了,不一定完全需要正方形图片传入网络,矩形训练就是将图片填充为最小的32的倍数边长,从而减小冗余信息。
在这里插入图片描述

矩形填充

值得一提的是,除了矩形训练,还有矩形推理,也就是在做检测的时候也这样填充,从而加快推理速度,减少推理时间。

import glob
import math
import os
import random
import shutil
import time
from pathlib import Path
from threading import Thread

import cv2
import numpy as np
import torch
from PIL import Image, ExifTags
from torch.utils.data import Dataset
from tqdm import tqdm

from utils.utils import xyxy2xywh, xywh2xyxy

help_url = 'https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data'
img_formats = ['.bmp', '.jpg', '.jpeg', '.png', '.tif', '.dng']
vid_formats = ['.mov', '.avi', '.mp4', '.mpg', '.mpeg', '.m4v', '.wmv', '.mkv']

# Get orientation exif tag
for orientation in ExifTags.TAGS.keys():
    if ExifTags.TAGS[orientation] == 'Orientation':
        break

# 此函数根据图片的信息获取图片的宽、高信息
def exif_size(img):
    # Returns exif-corrected PIL size
    s = img.size  # (width, height)
    try:
        rotation = dict(img._getexif().items())[orientation]
        if rotation == 6:  # rotation 270
            s = (s[1], s[0])
        elif rotation == 8:  # rotation 90
            s = (s[1], s[0])
    except:
        pass

    return s


# 根据LoadImagesAndLabels创建dataloader
def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=False, cache=False, pad=0.0, rect=False):
    """
    参数解析:
    path:包含图片路径的txt文件或者包含图片的文件夹路径
    imgsz:网络输入图片大小
    batch_size: 批次大小
    stride:网络下采样最大总步长
    opt:调用train.py时传入的参数,这里主要用到opt.single_cls,是否是单类数据集
    hyp:网络训练时的一些超参数,包括学习率等,这里主要用到里面一些关于数据增强(旋转、平移等)的系数
    augment:是否进行数据增强
    cache:是否提前缓存图片到内存,以便加快训练速度
    pad:设置矩形训练的shape时进行的填充
    rect:是否进行矩形训练
    """
    dataset = LoadImagesAndLabels(path, imgsz, batch_size,
                                  augment=augment,  # augment images
                                  hyp=hyp,  # augmentation hyperparameters
                                  rect=rect,  # rectangular training
                                  cache_images=cache,
                                  single_cls=opt.single_cls,
                                  stride=int(stride),
                                  pad=pad)

    batch_size = min(batch_size, len(dataset))
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=batch_size,
                                             num_workers=nw,
                                             pin_memory=True,
                                             collate_fn=LoadImagesAndLabels.collate_fn)
    return dataloader, dataset

class LoadImages:  # for inference
    def __init__(self, path, img_size=640):
        p = str(Path(path))  # os-agnostic
        # os.path.abspath(p)返回p的绝对路径
        p = os.path.abspath(p)  # absolute path
        # 如果采用正则化表达式提取图片/视频,直接使用glob获取文件路径
        if '*' in p:
            files = sorted(glob.glob(p))  # glob
        # 如果path是一个文件夹,使用glob获取全部文件路径
        elif os.path.isdir(p):
            files = sorted(glob.glob(os.path.join(p, '*.*')))  # dir
            # 是文件则直接获取
        elif os.path.isfile(p):
            files = [p]  # files
        else:
            raise Exception('ERROR: %s does not exist' % p)
        # os.path.splitext分离文件名和后缀(后缀包含.)
        # 分别提取图片和视频文件路径
        images = [x for x in files if os.path.splitext(x)[-1].lower() in img_formats]
        videos = [x for x in files if os.path.splitext(x)[-1].lower() in vid_formats]
        # 图片与视频数量
        ni, nv = len(images), len(videos)

        self.img_size = img_size  # 输入图片size
        self.files = images + videos  # 整合图片和视频路径到一个列表
        self.nf = ni + nv  # number of files # 总的文件数量
        # 设置判断是否为视频的bool变量,方便后面单独对视频进行处理
        self.video_flag = [False] * ni + [True] * nv
        # 初始化模块信息,代码中对于mode=images与mode=video有不同处理
        self.mode = 'images'
        # 如果包含视频文件,这初始化opencv中的视频模块,cap=cv2.VideoCapture等
        if any(videos):
            self.new_video(videos[0])  # new video
        else:
            self.cap = None
        # nf如果小于0,则打印提示信息
        assert self.nf > 0, 'No images or videos found in %s. Supported formats are:\nimages: %s\nvideos: %s' % \
                            (p, img_formats, vid_formats)

    def __iter__(self):
        self.count = 0
        return self

    def __next__(self):
        # self.count == self.nf表示数据读取完了
        if self.count == self.nf:
            raise StopIteration
        # 获取文件路径
        path = self.files[self.count]
        # 如果该文件为视频,
        if self.video_flag[self.count]:
            # Read video
            # 修改mode为ideo
            self.mode = 'video'
            # 获取当前帧 画面,ret_val为一个bool变量,直到视频读取完毕之前都为True
            ret_val, img0 = self.cap.read()
            # 如果当前视频读取结束,则读取下一个视频
            if not ret_val:
                self.count += 1
                # 释放视频对象
                self.cap.release()
                # self.count == self.nf表示视频已经读取完了
                if self.count == self.nf:  # last video
                    raise StopIteration
                else:
                    path = self.files[self.count]
                    self.new_video(path)
                    ret_val, img0 = self.cap.read()
            # 当前读取的帧数
            self.frame += 1
            # 打印信息
            print('video %g/%g (%g/%g) %s: ' % (self.count + 1, self.nf, self.frame, self.nframes, path), end='')

        else:
            # Read image
            # 读取图片
            self.count += 1
            img0 = cv2.imread(path)  # BGR
            assert img0 is not None, 'Image Not Found ' + path
            # 打印信息
            print('image %g/%g %s: ' % (self.count, self.nf, path), end='')

        # Padded resize
        # 对图片进行resize+pad
        img = letterbox(img0, new_shape=self.img_size)[0]

        # Convert
        # BGR转为RGB格式,channel轴换到前面
        img = img[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, to 3x416x416
        # 将数组内存转为连续,提高运行速度,(不转的话也可能会报错)
        img = np.ascontiguousarray(img)

        # cv2.imwrite(path + '.letterbox.jpg', 255 * img.transpose((1, 2, 0))[:, :, ::-1])  # save letterbox image
        # 返回路径,resize+pad的图片,原始图片,视频对象
        return path, img, img0, self.cap

    def new_video(self, path):
        # frame用来记录帧数
        self.frame = 0
        # 初始化视频对象
        self.cap = cv2.VideoCapture(path)
        # 视频文件中的总帧数
        self.nframes = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))

    def __len__(self):
        return self.nf  # number of files
        
class LoadStreams:  # multiple IP or RTSP cameras
    def __init__(self, sources='streams.txt', img_size=640):
        # 初始化mode为images
        self.mode = 'images'
        self.img_size = img_size
        # 如果sources为一个保存了多个视频流的文件
        # 获取每一个视频流,保存为一个列表
        if os.path.isfile(sources):
            with open(sources, 'r') as f:
                sources = [x.strip() for x in f.read().splitlines() if len(x.strip())]
        else:
            sources = [sources]
        # 视频流个数
        n = len(sources)
        self.imgs = [None] * n
        self.sources = sources
        for i, s in enumerate(sources):
            # Start the thread to read frames from the video stream
            # 打印当前视频/总视频数/视频流地址
            print('%g/%g: %s... ' % (i + 1, n, s), end='')
            # 如果source=0则打开摄像头,否则打开视频流地址
            cap = cv2.VideoCapture(0 if s == '0' else s)
            assert cap.isOpened(), 'Failed to open %s' % s
            # 获取视频的宽和长
            w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            # 获取视频的帧率
            fps = cap.get(cv2.CAP_PROP_FPS) % 100
            # 读取当前画面
            _, self.imgs[i] = cap.read()  # guarantee first frame
            # 创建多线程读取视频流,daemon表示主线程结束时子线程也结束
            thread = Thread(target=self.update, args
  • 48
    点赞
  • 257
    收藏
    觉得还不错? 一键收藏
  • 87
    评论
YoloV5的`detect.py`中,打开摄像头的代码可以在`detect()`函数中找到,具体代码如下: ```python def detect(): ... # 如果使用摄像头,打开摄像头并获取视频流 if opt.source == '0' or opt.source.startswith('rtsp') or opt.source.startswith('http'): # 打开摄像头或者视频流 view_img = True cudnn.benchmark = True # set True to speed up constant image size inference dataset = LoadStreams(opt.source, img_size=imgsz) else: # 读取本地视频文件 save_img = True dataset = LoadImages(opt.source, img_size=imgsz) # 读取模型 model = attempt_load(weights, map_location=device) # load FP32 model imgsz = check_img_size(imgsz, s=model.stride.max()) # check img_size if half: model.half() # to FP16 # 获取类别名称 names = model.module.names if hasattr(model, 'module') else model.names # 进行推理 results = [] for path, img, im0s, vid_cap in dataset: # img : 当前帧的缩放后的图片 # im0s : 当前帧的原图 # 进行检测 t1 = torch_utils.time_synchronized() img = torch.from_numpy(img).to(device) img = img.half() if half else img.float() img /= 255.0 # 0 - 255 to 0.0 - 1.0 if img.ndimension() == 3: img = img.unsqueeze(0) # 获取预测结果 pred = model(img, augment=opt.augment)[0] # 进行后处理 pred = non_max_suppression(pred, conf_thres=conf_thres, iou_thres=iou_thres, classes=opt.classes, agnostic=agnostic_nms, max_det=max_det) t2 = torch_utils.time_synchronized() # 输出当前帧信息 for i, det in enumerate(pred): # detections per image if webcam: # batch_size >= 1 p, s, im0 = path[i], '%g: ' % i, im0s[i] else: p, s, im0 = path, '', im0s save_path = str(Path(out) / Path(p).name) txt_path = str(Path(out) / Path(p).stem) + (f'_{frame_i:06d}' if save_img else '') s += '%gx%g ' % img.shape[2:] # print string gn = torch.tensor(im0.shape)[[1, 0, 1, 0]] # normalization gain whwh if det is not None and len(det): # Rescale boxes from img_size to im0 size det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round() # Print results for c in det[:, -1].unique(): n = (det[:, -1] == c).sum() # detections per class s += f"{n} {names[int(c)]}{'s' * (n > 1)}, " # add to string # Write results for *xyxy, conf, cls in det: if save_txt: # Write to file xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() # normalized xywh with open(txt_path + '.txt', 'a') as f: f.write(('%g ' * 5 + '\n') % (cls, *xywh)) # label format if save_img or view_img: # Add bbox to image label = f'{names[int(cls)]} {conf:.2f}' plot_one_box(xyxy, im0, label=label, color=colors[int(cls)], line_thickness=3) # Print time (inference + NMS) print(f'{s}Done. ({t2 - t1:.3f}s)') # Stream results if view_img: cv2.imshow(str(p), im0) if cv2.waitKey(1) == ord('q'): # q to quit raise StopIteration # Save results (image with detections) if save_img: if dataset.mode == 'images': cv2.imwrite(save_path, im0) print(f'Done. ({time.time() - t0:.3f}s)') ``` 在上面的代码中,如果`opt.source`为`0`或者以`rtsp`或`http`开头,则表示打开摄像头或视频流,代码中会调用`LoadStreams`函数加载视频流。在`for path, img, im0s, vid_cap in dataset:`这一行代码中,`img`表示当前帧的缩放后的图片,`im0s`表示当前帧的原图。在代码中,会对当前帧的图片进行目标检测,并对检测结果进行后处理,最后将结果输出到屏幕上或保存到本地。如果需要显示视频流,则调用`cv2.imshow`函数将当前帧的原图显示到屏幕上。
评论 87
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值