原始图像
输出结果
实验代码
# -*- coding: utf-8 -*-
"""
faster rcnn实现目标检测
"""
import os
import time
import torch.nn as nn
import torch
import numpy as np
import torchvision.transforms as transforms
import torchvision
from PIL import Image
from matplotlib import pyplot as plt
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# classes_coco
COCO_INSTANCE_CATEGORY_NAMES = [
'__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A',
'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table',
'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book',
'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
]
if __name__ == "__main__":
# path_img = os.path.join(BASE_DIR, "demo_img1.png")
# path_img = os.path.join(BASE_DIR, "demo_img2.png")
path_img = os.path.join(BASE_DIR, "小姐姐们.jpg")
# config
preprocess = transforms.Compose([
transforms.ToTensor(),
])
# 1. load data & model
input_image = Image.open(path_img).convert("RGB")
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
model.eval()
# 2. preprocess
img_chw = preprocess(input_image)
# 3. to device
if torch.cuda.is_available():
img_chw = img_chw.to('cuda')
model.to('cuda')
# 4. forward
# 这里图片不再是 BCHW 的形状,而是一个list,每个元素是图片
input_list = [img_chw]
with torch.no_grad():
tic = time.time()
print("input img tensor shape:{}".format(input_list[0].shape))
output_list = model(input_list)
# 输出也是一个 list,每个元素是一个 dict
output_dict = output_list[0]
print("pass: {:.3f}s".format(time.time() - tic))
for k, v in output_dict.items():
print("key:{}, value:{}".format(k, v))
# 5. visualization
out_boxes = output_dict["boxes"].cpu()
out_scores = output_dict["scores"].cpu()
out_labels = output_dict["labels"].cpu()
fig, ax = plt.subplots(figsize=(12, 12))
ax.imshow(input_image, aspect='equal')
num_boxes = out_boxes.shape[0]
# 这里最多绘制 40 个框
max_vis = 40
thres = 0.5
for idx in range(0, min(num_boxes, max_vis)):
score = out_scores[idx].numpy()
bbox = out_boxes[idx].numpy()
class_name = COCO_INSTANCE_CATEGORY_NAMES[out_labels[idx]]
# 如果分数小于这个阈值,则不绘制
if score < thres:
continue
ax.add_patch(plt.Rectangle((bbox[0], bbox[1]), bbox[2] - bbox[0], bbox[3] - bbox[1], fill=False,
edgecolor='red', linewidth=3.5))
ax.text(bbox[0], bbox[1] - 2, '{:s} {:.3f}'.format(class_name, score), bbox=dict(facecolor='blue', alpha=0.5),
fontsize=20, color='white')
plt.show()
plt.close()
# appendix
classes_pascal_voc = ['__background__',
'aeroplane', 'bicycle', 'bird', 'boat',
'bottle', 'bus', 'car', 'cat', 'chair',
'cow', 'diningtable', 'dog', 'horse',
'motorbike', 'person', 'pottedplant',
'sheep', 'sofa', 'train', 'tvmonitor']
# classes_coco
COCO_INSTANCE_CATEGORY_NAMES = [
'__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A',
'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table',
'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book',
'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
]
参考:https://github.com/zhangxiann/PyTorch_Practice
个人自学,不做盈利,如有侵权,立刻删除。