基于 YOLOv8n-pose 模型的图像特征提取,可用于识别特定的姿态

目录

1. __init__ 方法:初始化类的实例

2. save_pose_feat 方法:

3. load_db_pose_feat 方法:

4. cal_similarity 方法:


实现了一个基于 YOLOv8n-pose 模型的图像特征提取和相似性比较系统。它可以从图像中提取人体关键点信息,并将其保存为特征文件。然后,通过计算输入图像与数据库中图像特征的相似度,确定输入图像的类别。

1. __init__ 方法:初始化类的实例

加载 YOLOv8n-pose 模型并加载数据库中的姿态特征。
load_model 方法:加载 YOLOv8n-pose 模型。

 def load_model(self):
        model=YOLO('yolov8n-pose.pt')
        return model


extract_fact 方法:从输入图像中提取特征,包括目标框的坐标和人体关键点的归一化坐标。相当于是17个点相对于特定的框的位置做了归一化

    def extract_fact(self,img_path):
        list=[]
        result=self.model(img_path)
        x1,y1,x2,y2,conf,cls=result[0].boxes.data[0]
        x1,y1,x2,y2=x1.item(),y1.item(),x2.item(),y2.item()
        for x_y in result[0].keypoints.xy[0]:
            x,y=x_y
            x=x.item()
            y=y.item()
            x=(x-x1)/(x2-x1)
            y=(y-y2)/(y1-y2)
            list.append(x)
            list.append(y)
        # print(list)
        return list


2. save_pose_feat 方法:

从指定目录下的图像中提取姿态特征,并保存到文本文件中。

 def save_pose_feat(self):
        img_paths=glob.glob('image_arm\*\*')
        with open('feature.txt','w',encoding='utf-8') as f:
            for img_path in img_paths:
                img_name=img_path.split('\\')[-2]+' '
                list=self.extract_fact(img_path)
                f.write(img_name)
                list=str(list)
                f.write(list)
                f.write('\n')


3. load_db_pose_feat 方法:

从保存的特征文件中加载数据库中的图像名称和特征。

    def load_db_pose_feat(self):
        with open('feature.txt','r',encoding='utf-8') as f:
            lines=f.readlines()
            db_names=[]
            db_features=[]
            for line in lines:
                db_name=line.split(' ')[0]
                db_feature=line.split(' ',1)[1]
                db_feature=json.loads(db_feature)
                db_names.append(db_name)
                db_features.append(db_feature)
        return db_names,db_features


4. cal_similarity 方法:

计算输入图像与数据库中图像特征的相似度,并确定输入图像的类别。

 如果找出来最相似的三张图片是一样的,那么就可以成功预测出来

 def cal_similarity(self,img_path):#计算相似度
        db_names,db_features=self.db_names,self.db_features
        db_names=np.array(db_names)
        my_feature=self.extract_fact(img_path)
        db_features=np.array(db_features)
        my_feature=np.array(my_feature)
        dist=np.linalg.norm(my_feature-db_features,axis=1)
        stack_dist_name=np.column_stack((dist,db_names))
        sort_index=np.argsort(stack_dist_name[:,0])
        top3=stack_dist_name[sort_index][:3][:,1]
        top1=top3[0]
        count=0
        for i in top3[1:]:
            if i==top1:
                count+=1
        if count==2:
            print('类别是',top1)
        else:
            print('啥也不是')
        print()

完整代码如下:

import glob
import json
import os

import cv2
import numpy as np
from ultralytics import YOLO
class FrameFeat:
    def __init__(self):
        self.model=self.load_model()
        self.db_names,self.db_features=self.load_db_pose_feat()
    def load_model(self):
        model=YOLO('yolov8n-pose.pt')
        return model
    def extract_fact(self,img_path):
        list=[]
        result=self.model(img_path)
        x1,y1,x2,y2,conf,cls=result[0].boxes.data[0]
        x1,y1,x2,y2=x1.item(),y1.item(),x2.item(),y2.item()
        for x_y in result[0].keypoints.xy[0]:
            x,y=x_y
            x=x.item()
            y=y.item()
            x=(x-x1)/(x2-x1)
            y=(y-y2)/(y1-y2)
            list.append(x)
            list.append(y)
        # print(list)
        return list
    def save_pose_feat(self):
        img_paths=glob.glob('image_arm\*\*')
        with open('feature.txt','w',encoding='utf-8') as f:
            for img_path in img_paths:
                img_name=img_path.split('\\')[-2]+' '
                list=self.extract_fact(img_path)
                f.write(img_name)
                list=str(list)
                f.write(list)
                f.write('\n')
    def load_db_pose_feat(self):
        with open('feature.txt','r',encoding='utf-8') as f:
            lines=f.readlines()
            db_names=[]
            db_features=[]
            for line in lines:
                db_name=line.split(' ')[0]
                db_feature=line.split(' ',1)[1]
                db_feature=json.loads(db_feature)
                db_names.append(db_name)
                db_features.append(db_feature)
        return db_names,db_features
    def cal_similarity(self,img_path):#计算相似度
        db_names,db_features=self.db_names,self.db_features
        db_names=np.array(db_names)
        my_feature=self.extract_fact(img_path)
        db_features=np.array(db_features)
        my_feature=np.array(my_feature)
        dist=np.linalg.norm(my_feature-db_features,axis=1)
        stack_dist_name=np.column_stack((dist,db_names))
        sort_index=np.argsort(stack_dist_name[:,0])
        top3=stack_dist_name[sort_index][:3][:,1]
        top1=top3[0]
        count=0
        for i in top3[1:]:
            if i==top1:
                count+=1
        if count==2:
            print('类别是',top1)
        else:
            print('啥也不是')
        print()

if __name__ == '__main__':
    img_path=r'D:\AI_37\ultralytics-8.2.74\ultralytics\pos-detect\img_1.png'
    ff=FrameFeat()
    # ff.extract_fact(img_path)
    # ff.save_pose_feat()
    ff.load_db_pose_feat()
    ff.cal_similarity(img_path) 

数据的格式如下 

 

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

西柚与蓝莓

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值