2021SC@SDUSC山东大学软件学院软件工程应用与实践--YOLOV5代码分析(八)plots.py-1

2021SC@SDUSC

前言

这篇分析plot.py文件,就如其名称一样,主要是一些用以展示的代码,也不是核心代码

外部库

from copy import copy
from pathlib import Path

import cv2
import math
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sn
import torch
from PIL import Image, ImageDraw, ImageFont

from utils.general import user_config_dir, is_ascii, xywh2xyxy, xyxy2xywh
from utils.metrics import fitness

copy:用于对象的拷贝操作,该模块只提供了两个主要的方法,cpoy.cpoy与cpoy.deepcopy,分别表示浅复制和深复制

Path,cv2,math,numpy,pandas在general.py中已经介绍过了

matplotlib:是python最著名的绘图库,提供了一整套和matlab相似的命令API,是这个文件的主要外部库

seaborn:基于matplotlib的python可视化库,是在matplotlib的基础上进行了更高级的API封装。

 Colors类

class Colors:
    # Ultralytics color palette https://ultralytics.com/
    def __init__(self):
        # hex = matplotlib.colors.TABLEAU_COLORS.values()
        hex = ('FF3838', 'FF9D97', 'FF701F', 'FFB21D', 'CFD231', '48F90A', '92CC17', '3DDB86', '1A9334', '00D4BB',
               '2C99A8', '00C2FF', '344593', '6473FF', '0018EC', '8438FF', '520085', 'CB38FF', 'FF95C8', 'FF37C7')
        self.palette = [self.hex2rgb('#' + c) for c in hex]
        self.n = len(self.palette)

    def __call__(self, i, bgr=False):
        c = self.palette[int(i) % self.n]
        return (c[2], c[1], c[0]) if bgr else c

    @staticmethod
    def hex2rgb(h):  # rgb order (PIL)
        return tuple(int(h[1 + i:1 + i + 2], 16) for i in (0, 2, 4))

hex:十六进制格式的颜色

palette:rgbgeshideyanse

n:数组长度

 hex2rgb函数将以十六进制表示的颜色转换为RGB格式

call函数在调用时返回索引为i的颜色,当i超过n时用i模n的索引来取得颜色

check_font函数

def check_font(font='Arial.ttf', size=10):
    # Return a PIL TrueType Font, downloading to CONFIG_DIR if necessary
    font = Path(font)
    font = font if font.exists() else (CONFIG_DIR / font.name)
    try:
        return ImageFont.truetype(str(font) if font.exists() else font.name, size)
    except Exception as e:  # download if missing
        url = "https://ultralytics.com/assets/" + font.name
        print(f'Downloading {url} to {font}...')
        torch.hub.download_url_to_file(url, str(font))
        return ImageFont.truetype(str(font), size)

 font:检查的字体

该函数检查有否有对应的字体文件,没有从网上下载到对应的路径

PIL的ImageFont模块定义了相同名称的类,即ImageFont类。这个类的实力存储bitmap字体,用于ImageDraw类的text()方法,不多讲解,感兴趣的可以参考ImageFont 模块 — Pillow (PIL Fork) 8.4.0 文档

Annotator类

class Annotator:
    check_font()  # download TTF if necessary

    # YOLOv5 Annotator for train/val mosaics and jpgs and detect/hub inference annotations
    def __init__(self, im, line_width=None, font_size=None, font='Arial.ttf', pil=True):
        assert im.data.contiguous, 'Image not contiguous. Apply np.ascontiguousarray(im) to Annotator() input images.'
        self.pil = pil
        if self.pil:  # use PIL
            self.im = im if isinstance(im, Image.Image) else Image.fromarray(im)
            self.draw = ImageDraw.Draw(self.im)
            self.font = check_font(font, size=font_size or max(round(sum(self.im.size) / 2 * 0.035), 12))
            self.fh = self.font.getsize('a')[1] - 3  # font height
        else:  # use cv2
            self.im = im
        self.lw = line_width or max(round(sum(im.shape) / 2 * 0.003), 2)  # line width

    def box_label(self, box, label='', color=(128, 128, 128), txt_color=(255, 255, 255)):
        # Add one xyxy box to image with label
        if self.pil or not is_ascii(label):
            self.draw.rectangle(box, width=self.lw, outline=color)  # box
            if label:
                w, h = self.font.getsize(label)  # text width
                self.draw.rectangle([box[0], box[1] - self.fh, box[0] + w + 1, box[1] + 1], fill=color)
                # self.draw.text((box[0], box[1]), label, fill=txt_color, font=self.font, anchor='ls')  # for PIL>8.0
                self.draw.text((box[0], box[1] - h), label, fill=txt_color, font=self.font)
        else:  # cv2
            c1, c2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3]))
            cv2.rectangle(self.im, c1, c2, color, thickness=self.lw, lineType=cv2.LINE_AA)
            if label:
                tf = max(self.lw - 1, 1)  # font thickness
                w, h = cv2.getTextSize(label, 0, fontScale=self.lw / 3, thickness=tf)[0]
                c2 = c1[0] + w, c1[1] - h - 3
                cv2.rectangle(self.im, c1, c2, color, -1, cv2.LINE_AA)  # filled
                cv2.putText(self.im, label, (c1[0], c1[1] - 2), 0, self.lw / 3, txt_color, thickness=tf,
                            lineType=cv2.LINE_AA)

    def rectangle(self, xy, fill=None, outline=None, width=1):
        # Add rectangle to image (PIL-only)
        self.draw.rectangle(xy, fill, outline, width)

    def text(self, xy, text, txt_color=(255, 255, 255)):
        # Add text to image (PIL-only)
        w, h = self.font.getsize(text)  # text width, height
        self.draw.text((xy[0], xy[1] - h + 1), text, fill=txt_color, font=self.font)

    def result(self):
        # Return annotated image as array
        return np.asarray(self.im)

init方法:

im:图片

line_width:线宽

font_size:字体大小

font:字体名称

pil:是否使用pillow

如果使用pillow,将图片格式转换为pillow的格式,fh为字体高度

ImageDraw提供简单的二维图像Image物体,可以使用此模块创建新图像、对现有图像进行注释或润色,具体参考ImageDraw 模块 — Pillow (PIL Fork) 8.4.0 文档

lw为线宽

 box_label方法:向图片中增加一个xyxy的box,并且加上标签

box:xyxy的box

label:标签

无论使用PIL或者opencv都是在对图像加一个box,其格式是xyxy,即box左上角的点坐标和右下角点的坐标,并且标注box的标签

rectangle 方法:

向图像中画一个长方形

text方法:

向图像中添加box的标签

result方法:

返回最终的图像,其格式是numpy数组

该类实现了向图片中画出预测框并且添加标签

 

 如图是经过操作后的图像,标注出了预测框以及预测出来的类别以及置信度

hist2d函数

def hist2d(x, y, n=100):
    # 2d histogram used in labels.png and evolve.png
    xedges, yedges = np.linspace(x.min(), x.max(), n), np.linspace(y.min(), y.max(), n)
    hist, xedges, yedges = np.histogram2d(x, y, (xedges, yedges))
    xidx = np.clip(np.digitize(x, xedges) - 1, 0, hist.shape[0] - 1)
    yidx = np.clip(np.digitize(y, yedges) - 1, 0, hist.shape[1] - 1)
    return np.log(hist[xidx, yidx])

根据x,y的直方图分布,来返回绘制颜色,区间数量多的颜色更亮,反之更暗

x和y都是np数组

np.linspace(start,stop,num,endpoint,retstep,dtype)

在指定的间隔内返回均匀间隔的数字 ,返回num均匀分布的样本在[start,stop]之间

np.clip(a,a_min,a_max,out=None)是将a限定在a_min和a_max之间,当a大于a_max时返回a_max,a小于a_min返回a_min,否则返回a本身

np.histogram2d可以将两个二维数组做出它的直方图

butter_lowpass_filtfilt函数

def butter_lowpass_filtfilt(data, cutoff=1500, fs=50000, order=5):
    from scipy.signal import butter, filtfilt

    # https://stackoverflow.com/questions/28536191/how-to-filter-smooth-with-scipy-numpy
    def butter_lowpass(cutoff, fs, order):
        nyq = 0.5 * fs
        normal_cutoff = cutoff / nyq
        return butter(order, normal_cutoff, btype='low', analog=False)

    b, a = butter_lowpass(cutoff, fs, order=order)
    return filtfilt(b, a, data)  # forward-backward filter

data:原数据

cutoff:被丢掉的频率

fs:滤波器大小

这个函数实现了低通滤波,即保留图像中频率比较低的部分,丢掉频率高的部分,“低通”就是低频能够通过,高频无法通过。

butter为配置滤波器,filtfilt实现滤波

 具体可参考官网scipy.signal.butter — SciPy v1.7.1 Manual

 output_to_target函数

def output_to_target(output):
    # Convert model output to target format [batch_id, class_id, x, y, w, h, conf]
    targets = []
    for i, o in enumerate(output):
        for *box, conf, cls in o.cpu().numpy():
            targets.append([i, cls, *list(*xyxy2xywh(np.array(box)[None])), conf])
    return np.array(targets)

output:模型的输出

该函数将模型的输出转换为我们想要的格式,即[batch_id,class_id,x,y,w,h,conf] 

output的格式为[boxes,conf,cla],分别代表了预测框、置信度、类别

标签的格式为[batch_id,class_id,x,y,w,h,conf]* M,M为整个batch的预测框数量。

plot_images函数

def plot_images(images, targets, paths=None, fname='images.jpg', names=None, max_size=1920, max_subplots=16):

 images:一个batch的图片

labels:一个batch的标签

paths:一个batch的文件名

fname:保存可视化之后大图的文件路径

names:类别名

max_size:限制每张可视化图片的最大图片大小

max_subplots:最多可视化batch_size=16张图片

# Plot image grid with labels
    if isinstance(images, torch.Tensor):
        images = images.cpu().float().numpy()
    if isinstance(targets, torch.Tensor):
        targets = targets.cpu().numpy()
    if np.max(images[0]) <= 1:
        images *= 255.0  # de-normalise (optional)

 将images和labels从tensor转换为numpy类型

如果images为0-1,将其乘上255转换为0-255

bs, _, h, w = images.shape  # batch size, _, height, width
bs = min(bs, max_subplots)  # limit plot images
ns = np.ceil(bs ** 0.5)  # number of subplots (square)

bs,h,w分别为batch_size,图片的高度、宽度

    # Build Image
    mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8)  # init
    for i, im in enumerate(images):
        if i == max_subplots:  # if last batch has fewer images than we expect
            break
        x, y = int(w * (i // ns)), int(h * (i % ns))  # block origin
        im = im.transpose(1, 2, 0)
        mosaic[y:y + h, x:x + w, :] = im

mosaic为初始化大图

对images进行遍历,x和y为转化为mosaic的像素位置

这块代码就是将images进行放大,复制到mosaic

 

    # Resize (optional)
    scale = max_size / ns / max(h, w)
    if scale < 1:
        h = math.ceil(scale * h)
        w = math.ceil(scale * w)
        mosaic = cv2.resize(mosaic, tuple(int(x * ns) for x in (w, h)))

对mosaic进行resize,h和w分别为新的高和宽,scale为缩小倍数

    # Annotate
    fs = int((h + w) * ns * 0.01)  # font size
    annotator = Annotator(mosaic, line_width=round(fs / 10), font_size=fs)
    for i in range(i + 1):
        x, y = int(w * (i // ns)), int(h * (i % ns))  # block origin
        annotator.rectangle([x, y, x + w, y + h], None, (255, 255, 255), width=2)  # borders
        if paths:
            annotator.text((x + 5, y + 5 + h), text=Path(paths[i]).name[:40], txt_color=(220, 220, 220))  # filenames
        if len(targets) > 0:
            ti = targets[targets[:, 0] == i]  # image targets
            boxes = xywh2xyxy(ti[:, 2:6]).T
            classes = ti[:, 1].astype('int')
            labels = ti.shape[1] == 6  # labels if no conf column
            conf = None if labels else ti[:, 6]  # check for confidence presence (label vs pred)

            if boxes.shape[1]:
                if boxes.max() <= 1.01:  # if normalized with tolerance 0.01
                    boxes[[0, 2]] *= w  # scale to pixels
                    boxes[[1, 3]] *= h
                elif scale < 1:  # absolute coords need scale if image scales
                    boxes *= scale
            boxes[[0, 2]] += x
            boxes[[1, 3]] += y
            for j, box in enumerate(boxes.T.tolist()):
                cls = classes[j]
                color = colors(cls)
                cls = names[cls] if names else cls
                if labels or conf[j] > 0.25:  # 0.25 conf thresh
                    label = f'{cls}' if labels else f'{cls} {conf[j]:.1f}'
                    annotator.box_label(box, label, color=color)
    annotator.im.save(fname)  # save

接下来就是对图片进行标注,fs为font size,annotator为上面定义的类

x和y为左上角的点,然后用anatator画一个长方形,如果path不为空并标注出box的类别标签

image_targets为当前batch的标签,boxes、classes、labels、conf分别是预测框、类别、是否可视化标签、置信度,其中labels表示当image_targets.shape[1]==6时需要可视化的是标签而不是预测框。

如果预测框是归一化了的将其放大到原图大小,否则乘以scale_factor

接下来对boxes的坐标加上左上角的坐标,boxes原先的坐标是基于当前grid的左上角的相对坐标,加上左上角的坐标变换为全局坐标

接下来在子图上画框,cls、color为类别和颜色,如果是画预测框并且conf>0.25,则画出一个预测框,设置conf>0.25是为了去除掉那些重复预测出来的框。

最后将其保存在相应的路径下。

总结

本篇文章比较重要的部分就是对图片进行画框和标注类别的处理,还有一些方法还没有介绍到,将在下一篇文章继续介绍这部分内容。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值