Object detection with Faster R-CNN in PyTorch

1 Introduction

This post uses the official pretrained Faster R-CNN model from PyTorch (torchvision) to run object detection on the local webcam feed. Model documentation: https://pytorch.org/vision/stable/models/generated/torchvision.models.detection.fasterrcnn_resnet50_fpn.html

2 Implementation

Environment: torch 1.13.1+cu117

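import torch
import torchvision
from torchvision import transforms

import cv2 as cv
# Converts an H x W x C uint8 image into a C x H x W float tensor scaled to [0, 1]
transform = transforms.ToTensor()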

2.1 Loading the model

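  • weights=torchvision.models.detection.FasterRCNN_ResNet50_FPN_Weights.COCO_V1: load weights pretrained on the COCO dataset (downloaded automatically on first use)
  • pretrained=True: the older way to request pretrained weights. It has been deprecated since torchvision 0.13 in favor of weights, so it is not passed here.
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(
    weights=torchvision.models.detection.FasterRCNN_ResNet50_FPN_Weights.COCO_V1
)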

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
print(device)
model.to(device)

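COCO label names. The list index is the class id that the model outputs in labels. Index 0 is the background ("unlabeled"). Names such as "street sign" or "hat" occupy ids that COCO 2017 never annotates, so they are effectively never predicted.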

coco_labels_name = ["unlabeled", "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat","traffic light", "fire hydrant", "street sign", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse","sheep", "cow", "elephant", "bear", "zebra", "giraffe", "hat", "backpack", "umbrella", "shoe", "eye glasses", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard", "sports_ball", "kite", "baseball bat","baseball glove", "skateboard", "surfboard", "tennis racket", "bottle", "plate", "wine glass", "cup", "fork", "knife","spoon", "bowl", "banana", "apple", "sandwich", "orange", "broccoli", "carrot", "hot_dog", "pizza", "donut", "cake", "chair", "couch", "potted plant", "bed", "mirror", "dining table", "window", "desk", "toilet", "door", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone", "microwave", "oven", "toaster", "sink", "refrigerator", "blender", "book", "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush", "hair brush"]

2.2 Running inference

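According to the official documentation, in eval mode the model returns a list with one dict per input image.
[Figure: prediction output]
Each dict contains:

  • boxes: FloatTensor[N, 4], the predicted boxes in (x1, y1, x2, y2) format, i.e. the top-left and bottom-right corners in pixel coordinates of the input image
  • labels: Int64Tensor[N], the predicted class id of each box
  • scores: Tensor[N], the confidence of each box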
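To make this structure concrete, here is a plain-Python sketch. It uses a hand-written stand-in for one frame's output, with lists instead of tensors; the numbers are made up for illustration. The hypothetical summarize() helper walks the output the same way plot() does later: one dict per image, then boxes, labels and scores zipped together.

```python
# Hand-written stand-in for model([frame_tensor]); values are made up.
pred = [
    {
        "boxes": [[12.3, 40.0, 210.9, 380.5], [300.0, 120.2, 420.7, 260.0]],
        "labels": [1, 62],
        "scores": [0.97, 0.41],
    }
]

# Two-entry excerpt of coco_labels_name (id -> name)
names = {1: "person", 62: "chair"}


def summarize(pred, names):
    """Return one (image_index, name, score, box) tuple per detection."""
    rows = []
    for img_idx, p in enumerate(pred):
        for box, label, score in zip(p["boxes"], p["labels"], p["scores"]):
            # (x1, y1) is the top-left corner, (x2, y2) the bottom-right one
            x1, y1, x2, y2 = (int(v) for v in box)
            rows.append((img_idx, names[label], round(score, 2), (x1, y1, x2, y2)))
    return rows


for row in summarize(pred, names):
    print(row)
```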

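First, the main program. OpenCV opens the webcam, and each frame is fed to the model. The predictions are then filtered and drawn onto the frame. The helper functions save_threshold_item() and plot() are given afterwards, and they must be defined before the loop runs.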

model.eval()

cap = cv.VideoCapture(0)

if not cap.isOpened():
    print("Cannot open camera")
    exit()

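# Request an 800x500 capture size (the camera may fall back to the nearest supported resolution)
cap.set(cv.CAP_PROP_FRAME_WIDTH, 800)
cap.set(cv.CAP_PROP_FRAME_HEIGHT, 500)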

while True:
    # Capture frame by frame
    ret, frame = cap.read()
    # ret is True if the frame was read correctly
    if not ret:
        print("Can't receive frame (stream end?). Exiting ...")
        break

    # OpenCV delivers frames in BGR order, but the model expects RGB images
    rgb_frame = cv.cvtColor(frame, cv.COLOR_BGR2RGB)

    # Convert to a [C, H, W] float tensor normalized to [0, 1]
    tensor_frame = transform(rgb_frame).to(device)

    # The detection model takes a list of [C, H, W] tensors (one per image),
    # so no manual reshape to [1, C, H, W] is needed; no_grad skips autograd bookkeeping
    with torch.no_grad():
        pred = model([tensor_frame])

    # Filter pred: keep only the detections with confidence > threshold
    threshold = 0.5
    pred = save_threshold_item(pred, threshold)

    # Draw box, label and confidence on the image
    [frame] = plot(pred, [frame])

    # Display the resulting frame; press q to quit
    cv.imshow('frame', frame)
    if cv.waitKey(1) == ord('q'):
        break

# When everything is done, release the capture
cap.release()
cv.destroyAllWindows()
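OpenCV delivers frames as [H, W, C] uint8 arrays in BGR order. The model instead expects RGB input in [C, H, W] layout, scaled to [0, 1]. A plain-Python sketch of that conversion, applied to a tiny hand-written 1x2 frame, makes the three steps explicit. The hypothetical function to_chw_rgb() only illustrates the idea; in the real code, cv.cvtColor and transforms.ToTensor() do this on whole frames.

```python
def to_chw_rgb(frame_bgr):
    """Plain-Python stand-in for BGR->RGB conversion followed by ToTensor().

    frame_bgr: nested lists of shape [H][W][3], uint8 values in BGR order.
    Returns nested lists of shape [3][H][W], floats in [0, 1], RGB order.
    """
    height = len(frame_bgr)
    width = len(frame_bgr[0])
    chw = []
    for c in range(3):
        src = 2 - c  # RGB channel c comes from BGR channel 2 - c
        plane = [[frame_bgr[y][x][src] / 255.0 for x in range(width)]
                 for y in range(height)]
        chw.append(plane)
    return chw


# 1x2 frame in BGR order: a pure blue pixel, then a pure red pixel
frame = [[[255, 0, 0], [0, 0, 255]]]
t = to_chw_rgb(frame)
print(len(t), len(t[0]), len(t[0][0]))  # 3 1 2  (C, H, W)
print(t[0])  # red channel:  [[0.0, 1.0]]
print(t[2])  # blue channel: [[1.0, 0.0]]
```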

save_threshold_item() keeps only the predictions whose confidence is above a threshold:

def save_threshold_item(pred, threshold):
    """
    Keep only the predictions with confidence > threshold

    :param pred: model output (a list with one dict per image)
    :param threshold: confidence threshold
    :return: the filtered pred
    """
    for item in pred:
        # Boolean mask: True for every detection whose score is above the threshold
        keep = item['scores'] > threshold
        item['boxes'] = item['boxes'][keep]
        item['labels'] = item['labels'][keep]
        item['scores'] = item['scores'][keep]

    return pred
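A boolean mask has two advantages over searching for the first score below the threshold. It does not depend on the scores being sorted, and it does not drop anything when every score in a frame passes. Below is the same idea in plain Python, with lists standing in for tensors and made-up scores. keep_above() is a hypothetical name used only for this sketch.

```python
def keep_above(pred, threshold):
    """Keep detections whose score is strictly greater than threshold.

    pred: list of dicts with 'boxes', 'labels', 'scores' (plain lists here,
    standing in for the tensors the model returns).
    """
    filtered = []
    for p in pred:
        keep = [s > threshold for s in p["scores"]]  # boolean mask
        filtered.append({
            key: [v for v, ok in zip(p[key], keep) if ok]
            for key in ("boxes", "labels", "scores")
        })
    return filtered


frames = [
    # every score passes: both detections must survive
    {"boxes": [[0, 0, 10, 10], [5, 5, 20, 20]], "labels": [1, 3], "scores": [0.9, 0.8]},
    # only the first detection passes
    {"boxes": [[1, 1, 4, 4], [2, 2, 8, 8]], "labels": [17, 18], "scores": [0.7, 0.2]},
]
result = keep_above(frames, 0.5)
print([r["labels"] for r in result])  # [[1, 3], [17]]
```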

plot() draws the predicted boxes:

def plot(pred, images):

    """
    Draw box, label and confidence on the images
    :param pred: model output
    :param images: list of images read with OpenCV (one or more)
    :return: the list of annotated images (drawn in place)
    """

    for i, image in enumerate(images):

        boxes = pred[i]['boxes']
        labels = pred[i]['labels']
        scores = pred[i]['scores']
        for box, label, score in zip(boxes, labels, scores):

            # (xmin, ymin) is the top-left corner, (xmax, ymax) the bottom-right one
            xmin, ymin, xmax, ymax = (int(v) for v in box)

            # Draw the box (OpenCV colors are BGR, so this is blue)
            cv.rectangle(image, (xmin, ymin), (xmax, ymax), color=(255, 0, 0), thickness=2)

            # Label name plus confidence rounded to two decimals
            confidence = round(float(score), 2)
            text = coco_labels_name[int(label)] + " " + str(confidence)
            font = cv.FONT_HERSHEY_SIMPLEX
            # Red text anchored at the box's top-left corner
            cv.putText(image, text, (xmin, ymin), font, 0.8, (0, 0, 255), 1, cv.LINE_AA)

    return images
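cv.putText treats its origin as the bottom-left corner of the text. A label placed at (xmin, ymin) therefore sits just above the box and gets cut off when the box touches the top of the frame. Below is a small pure-Python sketch of two hypothetical helpers. label_text() builds the label string, with a fallback for unknown ids. text_origin() keeps the origin inside the frame. In real code, the text height could come from cv.getTextSize(text, font, 0.8, 1)[0][1].

```python
def label_text(label_id, score, names):
    """Build 'name 0.88'; fall back to the raw id if it is not in names."""
    if 0 <= label_id < len(names):
        name = names[label_id]
    else:
        name = f"id {label_id}"
    return f"{name} {score:.2f}"


def text_origin(xmin, ymin, text_height, margin=5):
    """Bottom-left origin for cv.putText: just above the box, but never
    closer to the top edge than text_height pixels."""
    return xmin, max(ymin - margin, text_height)


names = ["unlabeled", "person", "bicycle", "car"]  # excerpt of coco_labels_name
print(label_text(1, 0.876, names))  # person 0.88
print(label_text(99, 0.5, names))   # id 99 0.50
print(text_origin(10, 120, 17))     # (10, 115)
print(text_origin(10, 3, 17))       # (10, 17)
```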

3 Results

[Figure: detection result]

4 Full code

https://github.com/gwcrepo/pytorch-fasterrcnn_resnet50_fpn
