1.首先进行数据处理
import numpy as np  # numerical data processing
from matplotlib import pyplot as plt  # displaying images and plotting
from sklearn import svm  # support vector machine implementation
from sklearn.model_selection import train_test_split  # splitting the dataset
from sklearn.metrics import accuracy_score  # computing prediction accuracy
import cv2  # reading image files
import os  # file and directory access
import pickle  # saving and loading the trained model
from PIL import Image
SHAPE = (30, 30)  # size that every input image is resized to
1.1 数据集的目录结构如下
数据集根目录(如 derection/)下每个类别对应一个子文件夹,子文件夹名即类别标签,图片放在各自的子文件夹中,例如 derection/f1/1.jpg。
图片不用太多,每类几张即可。
替换成自己的图片和目录即可。
def getImageData(self, directory):
    """Load every image under ``directory`` and return features with labels.

    Each sub-directory reported by ``os.walk`` (at any depth) is treated as
    one class: its name becomes the label of every regular file directly
    inside it, and each file is converted to a feature vector with
    ``self.extractFeaturesFromImage``.

    Args:
        directory: Root folder of the dataset, with or without a trailing
            path separator.

    Returns:
        A ``(features, labels)`` tuple of NumPy arrays; ``labels[i]`` is the
        name of the directory that contained the file behind ``features[i]``.
        Both arrays are empty when no files are found.
    """
    feature_list = []
    label_list = []
    for root, dirs, _files in os.walk(directory):
        # Sorting in place fixes the traversal order, so the sample order
        # (and any seeded split made from it) does not depend on the
        # filesystem's listing order.
        dirs.sort()
        for class_name in dirs:
            # os.path.join instead of string concatenation: works whether or
            # not ``directory`` ends with a separator, and for nested roots.
            class_dir = os.path.join(root, class_name)
            for image_name in sorted(os.listdir(class_dir)):
                image_path = os.path.join(class_dir, image_name)
                if not os.path.isfile(image_path):
                    # Nested directories are visited by os.walk on their own;
                    # they are not images.
                    continue
                label_list.append(class_name)
                feature_list.append(self.extractFeaturesFromImage(image_path))
    return np.asarray(feature_list), np.asarray(label_list)
1.2 接下来是图片的预处理函数(上面的 getImageData 中调用了它)
def extractFeaturesFromImage(self, image_file):
    """Read an image file and convert it to a flat, mean-normalised vector.

    The image is loaded with OpenCV, resized to a fixed size with bicubic
    interpolation, flattened to one dimension and divided by its mean value.

    Args:
        image_file: Path of the image to load.

    Returns:
        A 1-D floating-point NumPy array holding the normalised pixel values.

    Raises:
        ValueError: If OpenCV cannot read ``image_file``.
    """
    img = cv2.imread(image_file)
    if img is None:
        # cv2.imread reports an unreadable/missing file by returning None
        # instead of raising, which would otherwise surface as an opaque
        # error inside cv2.resize.
        raise ValueError("cannot read image file: {0!r}".format(image_file))

    # Prefer a size configured on the instance; otherwise use the module-level
    # SHAPE constant. Looked up lazily so a missing fallback is only an error
    # when it is actually needed.
    target_size = getattr(self, "SHAPE", None)
    if target_size is None:
        target_size = SHAPE

    # Resize so every image yields a feature vector of the same length.
    img = cv2.resize(img, target_size, interpolation=cv2.INTER_CUBIC)
    img = img.flatten()

    mean = np.mean(img)
    if mean == 0:
        # An all-zero image has nothing to normalise; avoid 0/0 -> NaN.
        return img.astype(np.float64)
    return img / mean
2. SVM 模型训练
def train(self, dir):
    """Train an SVM classifier on an image dataset and save it to ``svm.pkl``.

    The dataset is loaded with ``self.getImageData``, split 80/20 into
    training and test parts with a fixed random seed, used to fit an
    ``svm.SVC``, evaluated on the held-out part, and finally pickled to
    ``svm.pkl`` in the current working directory.

    Args:
        dir: Dataset root directory. The name shadows the builtin but is kept
            for compatibility with existing callers. When it is empty or
            ``None``, ``self.directory`` is used instead.

    Returns:
        None. Progress and evaluation results are printed.
    """
    # Honour the argument (previously ignored); fall back to the instance
    # attribute only when no directory is supplied.
    directory = dir or self.directory

    # Load features and labels through the instance so ``self`` is bound.
    feature_array, label_array = self.getImageData(directory)

    # Hold out 20% of the samples for evaluation; fixed seed for repeatability.
    X_train, X_test, y_train, y_test = train_test_split(
        feature_array, label_array, test_size=0.2, random_state=42)
    print("shape of raw image data: {0}".format(feature_array.shape))
    print("shape of training data: {0}".format(X_train.shape))
    print("shape of test data: {0}".format(X_test.shape))

    # Model selection and fitting.
    clf = svm.SVC(gamma=0.001, C=100., probability=True)
    clf.fit(X_train, y_train)

    # Evaluate on the held-out samples.
    Ypred = clf.predict(X_test)
    print("pre", Ypred)
    print("test", y_test)
    print("accuracy", accuracy_score(y_test, Ypred))

    # Persist the model; the context manager guarantees the file is flushed
    # and closed even if pickling fails.
    with open("svm.pkl", "wb") as model_file:
        pickle.dump(clf, model_file)
3.模型读取使用
def test(self, path, img_file):
    """Predict the class of one image with a previously pickled classifier.

    Args:
        path: Path of the pickle file containing the trained classifier.
        img_file: Path of the image to classify.

    Returns:
        The result of ``clf.predict`` for the single image (one prediction).
    """
    # NOTE(security): pickle.load can execute arbitrary code embedded in the
    # file. Only load model files that come from a trusted source.
    with open(path, 'rb') as pkl_file:
        clf = pickle.load(pkl_file)

    features = self.extractFeaturesFromImage(img_file)
    # One row, as many columns as the feature vector has, so this keeps
    # working if the resize target (and thus the vector length) changes.
    Ypred = clf.predict(np.reshape(features, (1, -1)))
    return Ypred
4.运行代码
path = 'svm.pkl'  # location and file name of the saved model
img = 'derection/'  # dataset root directory
img_file = 'derection/f1/1.jpg'  # image used for the prediction demo
# NOTE(review): train and test are declared with a leading ``self`` parameter,
# but are called here without an instance; confirm how they are meant to be
# invoked (e.g. through an instance of the enclosing class).
train(img)
t = test(path, img_file)
print(t)
# Show the same file that was classified, rather than repeating the path as a
# second literal that could drift from img_file.
img = Image.open(img_file)
plt.figure("Image")  # name of the figure window
plt.imshow(img)
plt.axis('off')  # hide the axes
plt.title(t)  # use the prediction as the figure title
plt.show()