点击下方卡片,关注“小白玩转Python”公众号
机器学习中的目标检测任务涉及识别图像或视频中特定类别(如人、汽车或动物)的实例,然后通过在它们周围绘制边界框来准确定位这些实例。让我们快速尝试一个模型:我们将在图像中检测猫:
from transformers import pipeline
model = pipeline("object-detection")
result = model("cat.jpg")
result
"""
[{'score': 0.9988692402839661,
'label': 'cat',
'box': {'xmin': 854, 'ymin': 499, 'xmax': 4094, 'ymax': 2797}}]
"""
我们初始化了一个目标检测管道。result
是输出。
score
:表示模型对其预测的置信度。在这里,置信度约为 99.89%,表明模型非常自信检测到的对象是一只猫。label
:指定检测到的对象类别。模型将该对象标记为“cat”。box
:包含识别图像中检测到的对象位置的边界框坐标。边界框由以下坐标定义:xmin
和ymin
:边界框左上角的坐标。xmax
和ymax
:边界框右下角的坐标。
from PIL import Image, ImageDraw
image_path = "cat.jpg"
image = Image.open(image_path)
box = result[0]["box"]
draw = ImageDraw.Draw(image)
bounding_box = (box['xmin'], box['ymin'], box['xmax'], box['ymax'])
draw.rectangle(bounding_box, outline="red", width=10)
image.show()
让我们再试一个:
from transformers import DetrImageProcessor, DetrForObjectDetection
import torch
from PIL import Image
import requests
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
# you can specify the revision tag if you don't want the timm dependency
processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50", revision="no_timm")
model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50", revision="no_timm")
inputs = processor(images=image, return_tensors="pt")
outputs = model(**inputs)
# convert outputs (bounding boxes and class logits) to COCO API
# let's only keep detections with score > 0.9
target_sizes = torch.tensor([image.size[::-1]])
results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)[0]
for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
box = [round(i, 2) for i in box.tolist()]
print(
f"Detected {model.config.id2label[label.item()]} with confidence "
f"{round(score.item(), 3)} at location {box}"
)
"""
Detected remote with confidence 0.998 at location [40.16, 70.81, 175.55, 117.98]
Detected remote with confidence 0.996 at location [333.24, 72.55, 368.33, 187.66]
Detected couch with confidence 0.995 at location [-0.02, 1.15, 639.73, 473.76]
Detected cat with confidence 0.999 at location [13.24, 52.05, 314.02, 470.93]
Detected cat with confidence 0.999 at location [345.4, 23.85, 640.37, 368.72]
"""
DetrImageProcessor
和 DetrForObjectDetection
是从 transformers 库中导入的类,专门用于处理图像并使用 DETR 模型进行目标检测。
图像经过处理,转换为适合模型的格式,转换为张量(return_tensors="pt"
表示 PyTorch 张量)。然后,模型进行目标检测,返回包括边界框和类别 logits(归一化前的原始分数)的输出。
结果经过后处理,将模型输出转换为更易于使用的格式,包括通过置信度阈值(threshold=0.9
)过滤检测结果。只保留得分高于 0.9 的检测结果。
from PIL import ImageDraw
draw = ImageDraw.Draw(image)
detected_objects = [
{"label": "remote", "score": 0.998, "box": [40.16, 70.81, 175.55, 117.98]},
{"label": "remote", "score": 0.996, "box": [333.24, 72.55, 368.33, 187.66]},
{"label": "couch", "score": 0.995, "box": [-0.02, 1.15, 639.73, 473.76]},
{"label": "cat", "score": 0.999, "box": [13.24, 52.05, 314.02, 470.93]},
{"label": "cat", "score": 0.999, "box": [345.4, 23.85, 640.37, 368.72]}
]
for obj in detected_objects:
box = obj['box']
label = obj['label']
score = obj['score']
draw.rectangle(box, outline="red", width=2)
text = f"{label} {score:.3f}"
draw.text((box[0], box[1] - 10), text, fill="red")
image.show()
有各种适用于不同类型对象的目标检测模型。例如,模型可以检测 PDF 文档中的表格。
from huggingface_hub import hf_hub_download
from transformers import AutoImageProcessor, TableTransformerForObjectDetection
import torch
from PIL import Image
file_path = hf_hub_download(repo_id="nielsr/example-pdf", repo_type="dataset", filename="example_pdf.png")
image = Image.open(file_path).convert("RGB")
image_processor = AutoImageProcessor.from_pretrained("microsoft/table-transformer-detection")
model = TableTransformerForObjectDetection.from_pretrained("microsoft/table-transformer-detection")
inputs = image_processor(images=image, return_tensors="pt")
outputs = model(**inputs)
# convert outputs (bounding boxes and class logits) to Pascal VOC format (xmin, ymin, xmax, ymax)
target_sizes = torch.tensor([image.size[::-1]])
results = image_processor.post_process_object_detection(outputs, threshold=0.9, target_sizes=target_sizes)[
0
]
for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
box = [round(i, 2) for i in box.tolist()]
print(
f"Detected {model.config.id2label[label.item()]} with confidence "
f"{round(score.item(), 3)} at location {box}"
)
"""
Detected table with confidence 1.0 at location [202.1, 210.59, 1119.22, 385.09]
"""
· END ·
HAPPY LIFE
本文仅供学习交流使用,如有侵权请联系作者删除