本文通过在Python3.7下,利用sklearn+pca+svm对笑脸和不笑脸进行分类,图片的大小为55×55。
注:本文使用的数据集一共有42张图片,其中36张作为训练集,剩下6张作为测试集。其中笑的和不笑的都为21张。
此代码经过测试,是可以运行的,如果连百度网盘链接过期了,在评论区留下你的邮箱,我给你发过去。
链接:https://pan.baidu.com/s/1K3U8er-C8ThUBkuP2_ac5w
提取码:1iy3
复制这段内容后打开百度网盘手机App,操作更方便哦
代码如下:
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler # 标准化处理
from sklearn.decomposition import PCA # PCA
from sklearn.svm import SVC # SVM
from PIL import Image
'''
@author:xiao黄
time:2020-11-16
'''
Image_size = 55
'''
注意:自己测试的时候需要修改两个文件路径,就是你存放数据的路径
'''
def load_data():
'''
加载数据集
:return 训练数据和对应的标签 0 为笑脸,1为不笑
'''
train_data = []
train_label = []
file_path = 'C:\\Users\\HKZ\\Desktop\\2020\\faceImages\\train\\' # 训练集
for i in range(1,37):
path = file_path + str(i) + '.png'
image = Image.open(path).convert('L') # 读取图片
# if i == 1:
# print(image.shape())
img = np.reshape(image,(1,Image_size*Image_size)) # 转换成一维
train_data.extend(img)
if i <= 16:
train_label.append(0)
else:
train_label.append(1)
train_data = np.reshape(train_data,(36,Image_size*Image_size))
train_label = np.matrix(train_label).T
test_data = []
test_label = []
file_path_test = 'C:\\Users\\HKZ\\Desktop\\2020\\faceImages\\test\\' # 测试集
for i in range(1,7):
path = file_path_test + str(i) + '.png'
image = Image.open(path).convert('L') # 读取图片
img = np.reshape(image,(1,Image_size*Image_size)) # 转换成一维
test_data.extend(img)
if i <= 3:
test_label.append(0)
else:
test_label.append(1)
test_data = np.reshape(test_data,(6,Image_size*Image_size))
return np.matrix(train_data), train_label, np.matrix(test_data), test_label
def svm(trainDataSimplified, trainLabel, testDataSimplified):
clf2 = SVC(C=2.0) # C为分类数目
clf2.fit(trainDataSimplified, trainLabel) # 训练模型
return clf2.predict(testDataSimplified)
if __name__ == "__main__":
train_data,train_label,test_data,test_label = load_data()
pca = PCA(0.9, True, True) # 建立pca类,设置参数,保留90%的数据方差
# pca = PCA(n_components=15) # 降到15维
trainData = pca.fit_transform(train_data) # 拟合并降维训练数据
testData = pca.transform(test_data) # 降维测试数据
result = svm(trainData,train_label,testData)
print('预测结果',result)
print('测试集标签',test_label)
此代码经过测试,是可以运行的,如果连百度网盘链接过期了,在评论区留下你的邮箱,我给你发过去。