系统：Windows 11
PyTorch 版本：1.11.0
torchvision 版本：0.12.0
测试图片：pytorch/vision 仓库（GitHub，main 分支）中的 gallery/assets/person1.jpg
使用代码：
import torch
import torchvision.transforms
from torchvision.models.detection import keypointrcnn_resnet50_fpn
from torchvision.io import read_image
# Sample image from the torchvision repo (gallery/assets/person1.jpg).
# read_image returns a uint8 tensor; it is kept as-is because draw_keypoints
# further down needs the uint8 image.
person_int = read_image(r"E:\git_rep\vision-0.12.0\gallery\assets\person1.jpg")

# The detection model takes float images scaled to [0, 1]. convert_image_dtype
# performs the uint8 -> float conversion (divide by 255) directly on the tensor,
# avoiding the tensor -> PIL -> tensor round trip.
person_float = torchvision.transforms.functional.convert_image_dtype(
    person_int, torch.float
)

# Keypoint R-CNN with pretrained weights (torchvision 0.12 API; torchvision
# 0.13+ prefers the ``weights=`` argument instead of ``pretrained=True``).
model = keypointrcnn_resnet50_fpn(pretrained=True, progress=False)
model = model.eval()

# Pure inference: disable autograd so no computation graph is recorded.
with torch.no_grad():
    outputs = model([person_float])
print(outputs)

kpts = outputs[0]['keypoints']
scores = outputs[0]['scores']
print(kpts)
print(scores)

# Keep only the detections whose score exceeds the threshold.
detect_threshold = 0.75
idx = torch.where(scores > detect_threshold)
keypoints = kpts[idx]
print(keypoints)
import torch
import numpy as np
import matplotlib.pyplot as plt
import torchvision.transforms.functional as F
# Crop saved figures tightly around their content.
plt.rcParams.update({"savefig.bbox": "tight"})
def show(imgs):
    """Display one image tensor, or a list of them, side by side in one row.

    Each image is detached, converted with ``F.to_pil_image`` and drawn on its
    own subplot with all tick marks and tick labels removed.

    Args:
        imgs: a single image tensor or a list of image tensors (presumably
            (C, H, W) tensors accepted by ``to_pil_image`` -- verify).
    """
    if not isinstance(imgs, list):
        imgs = [imgs]
    # squeeze=False keeps ``axs`` 2-D even for a single image, so axs[0, i]
    # indexing works for any number of images.
    _, axs = plt.subplots(ncols=len(imgs), squeeze=False)
    for i, img in enumerate(imgs):
        # Drop any autograd history before converting to a PIL image.
        img = img.detach()
        img = F.to_pil_image(img)
        axs[0, i].imshow(np.asarray(img))
        axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
from torchvision.utils import draw_keypoints
# Overlay the kept keypoints as blue dots of radius 3 on the original uint8
# image, then display the annotated result.
res = draw_keypoints(
    image=person_int,
    keypoints=keypoints,
    colors="blue",
    radius=3,
)
show(res)
结果：
参考文献：
1. keypointrcnn_resnet50_fpn — Torchvision 0.13 documentation
2. Visualization utilities — Torchvision 0.13 documentation
3. Models and pre-trained weights — Torchvision 0.13 documentation
（注：以上为 0.13 版文档，其中通过 weights= 参数加载预训练权重；本文代码基于 0.12，使用 pretrained 参数。）