《博主简介》
小伙伴们好,我是阿旭。专注于人工智能、AIGC、python、计算机视觉相关分享研究。
✌更多学习资源,可关注公-仲-hao:【阿旭算法与机器学习】,共同学习交流~
👍感谢小伙伴们点赞、关注!
《------往期经典推荐------》
二、机器学习实战专栏【链接】,已更新31期,欢迎关注,持续更新中~~
三、深度学习【Pytorch】专栏【链接】
四、【Stable Diffusion绘画系列】专栏【链接】
五、YOLOv8改进专栏【链接】,持续更新中~~
六、YOLO性能对比专栏【链接】,持续更新中~
《------正文------》
引言
计算机视觉领域的一个案例研究是创建一个解决方案,使系统能够像人类一样“看到”和“理解”图像中的对象。这个任务被称为目标检测。在分类任务中,模型识别图像中对象的类型(例如,它是猫还是狗的图片)。然而,在对象检测中,目标不仅是识别对象,而且还要通过在其周围绘制边界框来确定其在图像中的位置。
物体检测在日常生活中起着至关重要的作用,从安全方面到交通控制,甚至是自动驾驶汽车技术。这项任务的复杂性主要体现在两个方面:首先,如何检测图像中不同大小、形状和方向的物体;其次,如何快速准确地执行这种检测,即使在经常有噪声和复杂背景的现实世界中也是如此。
Faster-R-CNN概述
最流行的对象检测方法之一是Faster-R-CNN(基于区域的卷积神经网络)。Faster-R-CNN是对之前开发的R-CNN和Fast R-CNN方法的改进,在检测速度和准确性方面都有显着提高。虽然现在有许多更新和更快的方法与Faster R-CNN相比,但仍然值得理解和尝试这种方法作为对象检测技术的介绍。
Faster R-CNN将区域建议网络(RPN)集成到对象检测架构中。借助RPN,Faster R-CNN可以使用卷积网络提取的特征更有效地生成区域建议(可能包含对象的区域)。这允许在单个网络中进行端到端的对象检测,消除了以前依赖于外部过程(如选择性搜索)的方法中存在的瓶颈。为了更深入地了解Faster R-CNN的工作原理,您可以参考本文。
实现Faster R-CNN用于对象检测
在这里,我们不会从头开始训练模型,而是使用torchvision中提供的预训练模型。本文中使用的图像可以在Kaggle上找到。首先,我们需要为这个项目导入一些库:
import cv2
import torch
import requests
import numpy as np
import torchvision
from PIL import Image
from torch import no_grad
import matplotlib.pyplot as plt
from torchvision import transforms
我们导入的库与计算机视觉有关。例如,OpenCV(cv2
)用于图像处理,NumPy(np
)用于计算,PIL用于加载图像,Matplotlib(plt
)用于可视化,Requests用于从Web下载图像。
创建辅助函数
# Function to get predictions with optional filtering by object and threshold
def get_predictions(pred, threshold=0.8, objects=None):
"""
Assign a string name to predicted classes and filter out predictions below a given threshold.Args:
pred: List containing tuples with class labels, probabilities, and bounding boxes.
threshold: Minimum probability required to consider a prediction valid.
objects: Optional list of object names to filter predictions.
Returns:
List of tuples containing class name, probability, and bounding box for each valid prediction.
"""
predicted_classes = [(COCO_INSTANCE_CATEGORY_NAMES[i], p, [(box[0], box[1]), (box[2], box[3])])
for i, p, box in zip(list(pred[0]['labels'].numpy()),
pred[0]['scores'].detach().numpy(),
list(pred[0]['boxes'].detach().numpy()))]
predicted_classes = [stuff for stuff in predicted_classes if stuff[1] > threshold]
if objects and predicted_classes:
predicted_classes = [(name, p, box) for name, p, box in predicted_classes if name in objects]
return predicted_classes
get_predictions
函数用于处理和过滤对象检测模型做出的预测。它接受三个参数:pred
作为模型的原始输出,阈值
用于过滤低置信度的预测,对象
作为可选列表用于过滤基于特定类的预测。首先,该函数通过创建包含类名、检测概率和边界框坐标的元组列表,将原始预测转换为更可读的格式。然后过滤掉不符合概率阈值的预测,如果提供了objects
参数,则只保留与指定对象名称匹配的预测。结果是一个经过过滤的元组列表,可供进一步分析或可视化。
# Function to draw bounding boxes around detected objects
def draw_box(predicted_classes, image, rect_th=1, text_size=1, text_th=1):
"""
Draw bounding boxes and labels around detected objects in an image.
Args:
predicted_classes: List of tuples containing class name, probability, and bounding box.
image: Image tensor on which boxes and labels will be drawn.
rect_th: Thickness of the rectangle.
text_size: Font size of the label text.
text_th: Thickness of the label text.
"""
img = (np.clip(cv2.cvtColor(np.clip(image.numpy().transpose((1, 2, 0)), 0, 1), cv2.COLOR_RGB2BGR), 0, 1) * 255).astype(np.uint8).copy()
for predicted_class in predicted_classes:
label, probability, box = predicted_class
t, l = box[0]
r, b = box[1]
t, l, r, b = [round(item) for item in [t, l, r, b]]
cv2.rectangle(img, (t, l), (r, b), (0, 255, 0), rect_th) # Draw Rectangle
cv2.putText(img, f"{label}: {str(round(probability, 2))}", (t, l), cv2.FONT_HERSHEY_SIMPLEX, text_size, (0, 255, 0), thickness=text_th)
plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
plt.show()
draw_box函数用于通过在检测到的对象周围绘制边界框和标签来可视化对象检测模型的结果。此函数还接受预测类列表、要处理的图像、边界框的厚度以及标签文本的大小和厚度的参数。该函数将图像从张量格式转换为NumPy数组。然后,它迭代预测类的列表,提取类信息,概率和边界框坐标,使用OpenCV在图像上绘制。之后,包含类名和概率的标签被添加到边界框的左上角。最后,图像被转换回RGB格式,并使用Matplotlib显示,使我们能够看到检测到的对象沿着及其边界框和标签。
# Function to clear GPU memory and delete images to free up RAM
def save_RAM(image_=False):
"""
Clear GPU memory and delete image variables to free up RAM.
Args:
image_: Boolean flag to indicate if the image object should be deleted.
"""
torch.cuda.empty_cache()
global image, img, pred
del img, pred
if image_:
image.close()
del image
save_RAM函数用于管理和释放内存,特别是在GPU内存有限且需要仔细管理的环境中,例如在深度学习模型推理期间。该函数主要用于清除GPU内存,并可选地从RAM中删除图像变量。这是通过调用torch.cuda.empty_cache()
来清除未使用的GPU内存缓存,使用del
从内存中删除img
和pred
变量,如果image_
parameter设置为True,则可以
选择删除image
变量来完成的。
准备模型
# Load Pre-Trained Faster RCNN Model
model_ = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
model_.eval() # Set the model to evaluation mode
# Disable gradient computation for all parameters
for name, param in model_.named_parameters():
param.requires_grad = False
print("Model loaded successfully.")
接下来,我们加载Faster R-CNN(基于区域的卷积神经网络)模型,该模型具有ResNet-50骨干和特征金字塔网络(FPN),该网络已针对对象检测任务进行了预训练。使用带有pretrained=True
参数的 torchvision.models.detection.fasterrcnn_resnet50_fpn()
函数加载模型,这意味着模型的权重使用预训练版本初始化。然后使用model_.eval()
将模型设置为评估模式,以确保层在推理过程中的行为是确定的。代码还通过为每个参数设置requires_grad= False
冻结模型的权重,从而防止在推理期间更新。
# Function to get predictions from the model
def model(x):
with no_grad():
yhat = model_(x)
return yhat
上面的模型
函数用于从预训练的对象检测模型生成预测。
# COCO class names
COCO_INSTANCE_CATEGORY_NAMES = [
'__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat',
'traffic light', 'fire hydrant', 'N/A', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse',
'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A', 'handbag',
'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove',
'skateboard', 'surfboard', 'tennis racket', 'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon',
'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake',
'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table', 'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop',
'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book',
'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
]
上面的列表COCO_CATEGORY_NAMES
包含COCO(上下文中的公共对象)数据集中使用的类名。
模型检测
一旦模型准备就绪,我们将通过预测各种图像中的对象来测试Faster R-CNN模型。下面是我们将执行对象检测的一些图像示例:
- 在包含一个人的图像上进行检测。
- 在包含多个人的图像上进行检测。
- 检测一只猫和一只狗。
- 检测一辆汽车和一架飞机。
- 从互联网上下载的图像上的检测。
# 1. Predicting a Person
img_path = '/kaggle/input/sample-images-for-object-detection/ronaldo.jpg'
image = Image.open(img_path)
image = image.resize([int(0.5 * s) for s in image.size])
plt.imshow(image)
plt.show()
transform = transforms.Compose([transforms.ToTensor()])
img = transform(image)
pred = model([img])
pred_class = get_predictions(pred, objects=["person"])
draw_box(pred_class, img)
save_RAM(image_=True)
# 2. Predicting People
img_path = '/kaggle/input/sample-images-for-object-detection/people.jpg'
image = Image.open(img_path)
image = image.resize([int(0.5 * s) for s in image.size])
plt.imshow(image)
plt.show()
img = transform(image)
pred = model([img])
pred_thresh = get_predictions(pred, threshold=0.9, objects=["person"])
draw_box(pred_thresh, img, rect_th=1, text_size=1, text_th=1)
save_RAM(image_=True)
# 3. Predicting Cat and Dog
img_path = '/kaggle/input/sample-images-for-object-detection/catanddog.jpg'
image = Image.open(img_path)
image = image.resize([int(0.5 * s) for s in image.size])
plt.imshow(image)
plt.show()
img = transform(image)
pred = model([img])
pred_thresh = get_predictions(pred, threshold=0.8)
draw_box(pred_thresh, img, rect_th=10, text_size=10, text_th=10)
save_RAM(image_=True)
# 4. Predicting a Car and a Plane
img_path = '/kaggle/input/sample-images-for-object-detection/carandplane.jpg'
image = Image.open(img_path)
image is resized to [int(0.5 * s) for s in image size]
plt.imshow(image)
plt.show()
img = transform(image)
pred = model([img])
pred_thresh = get_predictions(pred, threshold=0.9)
draw_box(pred_thresh, img)
save_RAM(image_=True)
# 5. Predicting on an Uploaded Image
url = 'https://www.plastform.ca/wp-content/themes/plastform/images/slider-image-2.jpg'
image = Image.open(requests.get(url, stream=True).raw).convert('RGB')
plt.imshow(image)
plt.show()
img = transform(image)
pred = model([img])
pred_thresh = get_predictions(pred, threshold=0.95)
draw_box(pred_thresh, img, rect_th=2, text_size=1.5, text_th=2)
save_RAM(image_=True)
结论
在本文中,我们探讨了如何使用Faster R-CNN模型进行对象检测。我们在这里所做的只是一小部分,还有很多进一步的探索,你可以自己尝试。这可能包括在不同的数据上尝试,调整阈值,或微调模型以检测更多的对象。希望这篇文章对大家有帮助,谢谢大家!
好了,这篇文章就介绍到这里,喜欢的小伙伴感谢给点个赞和关注,更多精彩内容持续更新~~
关于本篇文章大家有任何建议或意见,欢迎在评论区留言交流!