机器学习:KNN完成英文手语分类

机器学习:KNN完成英文手语分类

先来看一下标准的英文手语:
在这里插入图片描述

本文KNN网络识别部分效果:(当然只挑选了几个字母):
在这里插入图片描述在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

完成的主要流程为:

  1. 通过mediapipe处理kaggle上的一个图片数据集,获取21个标志点的位置关系
  2. 建立KNN模型
  3. 模型检测

建立数据集

图片数据集来源Kaggle上的一个数据集,数据集链接
下载完成后我们得到是一堆128*128的图像,不会DL的菜鸡只能借助mediapipe完成标志点信息的提取(mediapipe的相关分装在另一篇文章中有介绍,也可见文末)。
处理流程为:

  1. 遍历图片数据集,利用mediapipe完成手势信息的提取(为提高准确度,可调节置信区间,在另一篇文章中有介绍),也就是21个坐标点的xy坐标。
  2. 数据集的提供者将图片文件的最后一个字母设置为手势的含义,我们可以将他作为我们的目标值保存到文件中。
import pandas as pd
import HandTrackingModule as htm
import cv2
import os
import time
import numpy as np
import csv


detector = htm.handDetctor(mode=True, detectionCon=0.6, trackCon=0.6)
csv_col_name = ['0_x', '0_y', '1_x', '1_y', '2_x', '2_y', '3_x', '3_y', '4_x', '4_y', '5_x', '5_y',
                '6_x', '6_y', '7_x', '7_y', '8_x', '8_y', '9_x', '9_y', '10_x', '10_y', '11_x', '11_y',
                '12_x', '12_y', '13_x', '13_y', '14_x', '14_y', '15_x', '15_y', '16_x', '16_y', '17_x', '17_y',
                '18_x', '18_y', '19_x', '19_y', '20_x', '20_y','target']


def load_image():

    path = "dataset5"
    dirs = os.listdir(path)
    for file_ABCD in dirs:
        for file_abcd in os.listdir(path+"/"+file_ABCD):
            for img_path in os.listdir(path+"/"+file_ABCD+"/"+file_abcd):
                # 跳过暗的图片
                if "depth" in img_path:
                    continue
                # print(path+"/"+file_ABCD+"/"+file_abcd+"/"+img_path)
                img = cv2.imread(path+"/"+file_ABCD+"/"+file_abcd+"/"+img_path)
                # 对图像进行处理
                img = detector.findHands(img)
                lmList = detector.findPosition(img, draw=False)
                if len(lmList) == 42:
                    # print(lmList)
                    # print(file_abcd)
                    lmList.append(file_abcd)
                    # 将特征点写入csv文件中
                    write_to_csv(lmList)
                cv2.imshow("show", img)
                cv2.waitKey(1)


def write_to_csv(lmList):
    # test = pd.DataFrame(columns=csv_col_name, data=[lmList])
    # test.to_csv('testcsv.csv', index=True)
    with open(r'testcsv.csv', mode='a', newline='', encoding='utf8') as cfa:
        csv_write = csv.writer(cfa)
        csv_write.writerow(lmList)
    return None


if __name__ == '__main__':
    load_image()

我们将数据保存到testcsv.csv文件中,如下图:
在这里插入图片描述

其实这数据量有点小

建立模型以及评估

建立模型:

import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import GridSearchCV
import joblib


def load_data():
    data = pd.read_csv("testcsv.csv")
    # print(data.iloc[:, 0:42])
    # 划分数据集
    x_train, x_test, y_train, y_test = train_test_split(data.iloc[:, 0:42], data.target, random_state=10)

    # 标准化
    transfer = StandardScaler()
    x_train = transfer.fit_transform(x_train)
    # 训练集和测试集做相同处理(很重要!)
    x_test = transfer.transform(x_test)

    # KNN算法预估器  建立模型
    estimator = KNeighborsClassifier(n_neighbors=10)
    # 添加网格搜索交叉验证
    param_dict = {"n_neighbors": [11, 13, 15, 17, 19, 21]}
    estimator = GridSearchCV(estimator, param_grid=param_dict, cv=10)

    estimator.fit(x_train, y_train)

    # 模型评估
    # 1 直接对比真实值和预估值
    y_predict = estimator.predict(x_test)
    print(y_predict == y_test)

    # 计算准确率
    score = estimator.score(x_test, y_test)
    print(score)

    # 保存模型
    joblib.dump(estimator, "k_near.pkl")

实时检测:


import HandTrackingModule as htm
import cv2
import numpy as np

wCam, hCam = 640, 480
cap = cv2.VideoCapture(0, cv2.CAP_DSHOW)
cap.set(3, wCam)
cap.set(4, hCam)

detector = htm.handDetctor(detectionCon=0.6, trackCon=0.6)
model = joblib.load("k_near.pkl")

while True:
    success, img = cap.read()

    img = detector.findHands(img)
    lmList = detector.findPosition(img, draw=False)
    if len(lmList) == 42:
        lm = transfer.transform(np.array(lmList).reshape(1, -1))
        m_predict = model.predict(lm)
        cv2.putText(img, str(m_predict), (10, 70), cv2.FONT_HERSHEY_PLAIN, 3, (255, 0, 255), 3)
    cv2.imshow("image", img)
    if cv2.waitKey(2) & 0xFF == 27:
        break

可优化点

  1. KNN的参数和网格搜索的范围都可以再进行优化。
  2. KNN本身只适用于小数据场景,且对K值敏感,总体手势识别效果一般。
  3. 数据预处理部分极为粗糙,可进行数据清洗等操作

考研了,也懒得整了

HandTrackingModule.py

import cv2
import mediapipe as mp
import time
import math


class handDetctor():
    def __init__(self, mode=False, maxHands=2, detectionCon=0.5, trackCon=0.5):
        self.mode = mode
        self.maxHands = maxHands
        self.detectionCon = detectionCon
        self.trackCon = trackCon

        self.mpHands = mp.solutions.hands
        self.hands = self.mpHands.Hands(self.mode, self.maxHands,
                                        self.detectionCon, self.trackCon)
        self.mpDraw = mp.solutions.drawing_utils

    def findHands(self, img, draw=True, ):
        imgRGB = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)#转换为rgb
        self.results = self.hands.process(imgRGB)

        # print(results.multi_hand_landmarks)
        if self.results.multi_hand_landmarks:
            for handLms in self.results.multi_hand_landmarks:
                if draw:
                    self.mpDraw.draw_landmarks(img, handLms, self.mpHands.HAND_CONNECTIONS)

        return img

    def findPosition(self, img, handNo=0, draw=True):
        lmList = []
        if self.results.multi_hand_landmarks:
            myHand = self.results.multi_hand_landmarks[handNo]
            for id, lm in enumerate(myHand.landmark):
                # print(id, lm)
                # 获取手指关节点
                h, w, c = img.shape
                # cx, cy = int(lm.x*w), int(lm.y*h)
                lmList.append(lm.x)
                lmList.append(lm.y)
                # if draw:
                #     cv2.putText(img, str(int(id)), (cx+10, cy+10), cv2.FONT_HERSHEY_PLAIN,
                #                 1, (0, 0, 255), 2)

        return lmList

    def fingerStatus(self, lmList):
    # 返回列表 包含每个手指的开合状态
        fingerList = []
        id, originx, originy = lmList[0]
        keypoint_list = [[2, 4], [6, 8], [10, 12], [14, 16], [18, 20]]
        for point in keypoint_list:
            id, x1, y1 = lmList[point[0]]
            id, x2, y2 = lmList[point[1]]
            if math.hypot(x2-originx, y2-originy) > math.hypot(x1-originx, y1-originy):
                fingerList.append(True)
            else:
                fingerList.append(False)

        return fingerList

  • 1
    点赞
  • 17
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 9
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 9
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Zccccccc_tz

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值