1 Introduction
This post uses PyTorch's official pretrained Faster R-CNN model to run real-time object detection on a local webcam. Faster R-CNN documentation: https://pytorch.org/vision/stable/models/generated/torchvision.models.detection.fasterrcnn_resnet50_fpn.html
2 Implementation
Environment: torch 1.13.1+cu117
import torch
import torchvision
import cv2 as cv
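2.1 Downloading the model
- pretrained=True: load pretrained weights. Since torchvision 0.13 this argument is deprecated in favor of weights and is not needed once weights is given.
- weights=torchvision.models.detection.FasterRCNN_ResNet50_FPN_Weights.COCO_V1: use weights pretrained on the COCO dataset (downloaded automatically on first use).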
# Passing weights alone selects (and on first use downloads) the COCO-pretrained weights
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(
    weights=torchvision.models.detection.FasterRCNN_ResNet50_FPN_Weights.COCO_V1
)
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
print(device)
model.to(device)
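COCO label names (index 0 is the background class; the list is indexed by the labels values the model returns):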
coco_labels_name = ["unlabeled", "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat","traffic light", "fire hydrant", "street sign", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse","sheep", "cow", "elephant", "bear", "zebra", "giraffe", "hat", "backpack", "umbrella", "shoe", "eye glasses", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard", "sports_ball", "kite", "baseball bat","baseball glove", "skateboard", "surfboard", "tennis racket", "bottle", "plate", "wine glass", "cup", "fork", "knife","spoon", "bowl", "banana", "apple", "sandwich", "orange", "broccoli", "carrot", "hot_dog", "pizza", "donut", "cake", "chair", "couch", "potted plant", "bed", "mirror", "dining table", "window", "desk", "toilet", "door", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone", "microwave", "oven", "toaster", "sink", "refrigerator", "blender", "book", "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush", "hair brush"]
2.2 Model prediction
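According to the official documentation, in inference (eval) mode the model returns a list with one dict per input image, each containing:
- boxes (FloatTensor[N, 4]): the predicted boxes in [x1, y1, x2, y2] pixel format, i.e. the top-left and bottom-right corners
- labels (Int64Tensor[N]): the predicted class label of each detection
- scores (Tensor[N]): the confidence score of each detection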
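To make this layout concrete, here is a minimal sketch that builds one prediction dict by hand (the values are made up, not real model output) and unpacks it into plain Python tuples, deriving width and height from the two corners. The helper unpack_prediction and the names list are hypothetical.

```python
import torch


def unpack_prediction(p, names):
    """Turn one prediction dict into (name, score, (x1, y1, x2, y2), (w, h)) tuples."""
    results = []
    for box, label, score in zip(p["boxes"].tolist(), p["labels"].tolist(), p["scores"].tolist()):
        x1, y1, x2, y2 = box
        # Width and height follow directly from the top-left and bottom-right corners
        results.append((names[label], round(score, 2), (x1, y1, x2, y2), (x2 - x1, y2 - y1)))
    return results


if __name__ == "__main__":
    # Made-up values shaped like the model output for a single image
    dummy = {
        "boxes": torch.tensor([[10.0, 20.0, 110.0, 220.0]]),
        "labels": torch.tensor([1], dtype=torch.int64),
        "scores": torch.tensor([0.98]),
    }
    print(unpack_prediction(dummy, ["unlabeled", "person"]))
```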
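First, the main code: it opens the webcam with OpenCV, feeds each frame to the model, and post-processes the predictions. It calls two helpers, save_threshold_item() and plot(), which are listed afterwards; in a script they must be defined before this loop runs.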
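model.eval()
# ToTensor converts an H x W x C uint8 image into a C x H x W float tensor in [0, 1]
transform = torchvision.transforms.ToTensor()
# Keep detections with confidence > threshold
threshold = 0.5

cap = cv.VideoCapture(0)
if not cap.isOpened():
    print("Cannot open camera")
    exit()
# Request the capture size (the camera may fall back to the nearest supported resolution)
cap.set(cv.CAP_PROP_FRAME_WIDTH, 800)
cap.set(cv.CAP_PROP_FRAME_HEIGHT, 500)
while True:
    # Capture frame by frame
    ret, frame = cap.read()
    # ret is True if the frame was read correctly
    if not ret:
        print("Can't receive frame (stream end?). Exiting ...")
        break
    # OpenCV delivers BGR but the model expects RGB: convert, then make a [C, H, W] tensor in [0, 1]
    rgb_frame = cv.cvtColor(frame, cv.COLOR_BGR2RGB)
    tensor_frame = transform(rgb_frame).to(device)
    # The model takes a list of [C, H, W] tensors (one per image), so no manual
    # reshape to [1, C, H, W] is needed; no_grad skips gradient bookkeeping
    with torch.no_grad():
        pred = model([tensor_frame])
    # Keep only the detections with confidence > threshold
    pred = save_threshold_item(pred, threshold)
    # Draw box, label and confidence on the original (BGR) frame
    [frame] = plot(pred, [frame])
    # Show the resulting frame; press q to quit
    cv.imshow('frame', frame)
    if cv.waitKey(1) == ord('q'):
        break
# Release the capture when everything is done
cap.release()
cv.destroyAllWindows()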
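The preprocessing inside the loop can be checked without a camera. The sketch below, assuming a uint8 BGR frame like the ones cap.read() returns, reproduces cv.cvtColor(..., cv.COLOR_BGR2RGB) followed by ToTensor() with plain numpy and torch: reverse the channel axis, move it to the front, and divide by 255. frame_to_tensor is a hypothetical name used only for illustration.

```python
import numpy as np
import torch


def frame_to_tensor(frame: np.ndarray) -> torch.Tensor:
    """Convert an H x W x 3 uint8 BGR frame to a 3 x H x W float RGB tensor in [0, 1]."""
    # Reverse the last axis (BGR -> RGB); copy so the array is contiguous for torch
    rgb = np.ascontiguousarray(frame[:, :, ::-1])
    # HWC -> CHW, then scale 0..255 to 0..1
    return torch.from_numpy(rgb).permute(2, 0, 1).float().div(255.0)


if __name__ == "__main__":
    # A 2 x 3 frame that is pure blue in BGR order
    frame = np.zeros((2, 3, 3), dtype=np.uint8)
    frame[:, :, 0] = 255
    t = frame_to_tensor(frame)
    print(t.shape)            # torch.Size([3, 2, 3])
    print(t[2].max().item())  # 1.0 (blue ends up in the last RGB channel)
```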
save_threshold_item() keeps only the predictions whose confidence exceeds a threshold:
def save_threshold_item(pred, threshold):
    """
    Keep only the predictions with confidence > threshold
    :param pred: model output (a list with one dict per image)
    :param threshold: confidence threshold
    :return: the filtered pred
    """
    for p in pred:
        # Boolean mask of detections above the threshold. Unlike truncating at the
        # first low score, this also keeps everything when all scores pass.
        keep = p['scores'] > threshold
        p['boxes'] = p['boxes'][keep]
        p['labels'] = p['labels'][keep]
        p['scores'] = p['scores'][keep]
    return pred
plot() draws the predicted boxes, labels and confidences:
def plot(pred, images):
    """
    Draw box, label and confidence on the images
    :param pred: model output (a list with one dict per image)
    :param images: images read by OpenCV (BGR); several can be passed in a list
    :return: the list of images with the detections drawn in place
    """
    font = cv.FONT_HERSHEY_SIMPLEX
    for image, p in zip(images, pred):
        for box, label, score in zip(p['boxes'], p['labels'], p['scores']):
            xmin, ymin, xmax, ymax = (int(v) for v in box.tolist())
            # Draw the box (OpenCV colors are BGR, so (255, 0, 0) is blue)
            cv.rectangle(image, (xmin, ymin), (xmax, ymax), (255, 0, 0), 2)
            # Label name plus confidence rounded to 2 decimals
            confidence = round(float(score), 2)
            text = coco_labels_name[int(label)] + " " + str(confidence)
            cv.putText(image, text, (xmin, ymin), font, 0.8, (0, 0, 255), 1, cv.LINE_AA)
    return images