import os
import sys

import cv2
import numpy as np
from imutils import paths
from IPython.display import Image
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.svm import LinearSVC

def extract_histogram(image, bins=(8, 8, 8)):
    """Return a flattened, normalised 3-channel colour histogram of *image*.

    Args:
        image: image array with (at least) three channels; the training code
            loads it with ``cv2.imread(path, 1)``.
        bins: number of bins per channel, in channel order 0, 1, 2.

    Returns:
        1-D array of length ``bins[0] * bins[1] * bins[2]`` (512 for the
        default), normalised in place with ``cv2.normalize``'s defaults.
    """
    channels = [0, 1, 2]
    # Every channel is binned over the same [0, 256) value range.
    value_ranges = [0, 256] * 3
    histogram = cv2.calcHist([image], channels, None, bins, value_ranges)
    cv2.normalize(histogram, histogram)
    return histogram.flatten()
# Build the training set: one colour-histogram feature vector per image in
# ./data/train. The label is the file-name prefix before the first dot
# (assumes Kaggle-style names such as "cat.123.jpg" -> "cat").
imagePaths = sorted(paths.list_images('./data/train'))
data = []
labels = []
for imagePath in imagePaths:
    image = cv2.imread(imagePath, 1)
    if image is None:
        # cv2.imread returns None (it does not raise) for unreadable or
        # unsupported files; skip them instead of crashing in calcHist.
        print(f"warning: skipping unreadable image {imagePath}", file=sys.stderr)
        continue
    label = os.path.basename(imagePath).split(".")[0]
    data.append(extract_histogram(image))
    labels.append(label)

if not data:
    # Fail with a clear message rather than an opaque train_test_split error.
    raise SystemExit("No readable images found in ./data/train")

# Encode the string labels as integers (the mapping is kept in le.classes_).
le = LabelEncoder()
labels = le.fit_transform(labels)

# Hold out 25% of the samples for evaluation; fixed seeds keep results reproducible.
(trainData, testData, trainLabels, testLabels) = train_test_split(
    np.array(data), labels, test_size=0.25, random_state=2
)
model = LinearSVC(random_state=2, C=0.94)
model.fit(trainData, trainLabels)
# Outputs #1-#3: three individual weights of the trained linear SVM, rounded
# to two decimals. coef_[0] is the first (for a two-class problem, the only)
# row of the weight matrix.
for featureIndex in (280, 129, 440):
    print(np.round(model.coef_[0][featureIndex], 2))

# Output #4: macro-averaged F1 score on the held-out split.
from sklearn.metrics import f1_score
predictions = model.predict(testData)
macroF1 = f1_score(testLabels, predictions, average='macro')
print(np.round(macroF1, 2))
# For a per-class breakdown, use instead:
# print(classification_report(testLabels, predictions, target_names=le.classes_))
# Outputs #5-#8: classify individual images from ./data/test.
def predict_image(path):
    """Return the model's encoded prediction for the image at *path*.

    The image goes through the same histogram features as the training data.
    The result is a 1-element array of encoded labels;
    ``le.inverse_transform`` maps it back to the class name.

    Raises:
        FileNotFoundError: if OpenCV cannot read the file.
    """
    # Same imread flag as the training loop.
    image = cv2.imread(path, 1)
    if image is None:
        # cv2.imread signals failure by returning None, not by raising.
        raise FileNotFoundError(f"Could not read image: {path}")
    # predict() expects a 2-D (n_samples, n_features) array, so add a batch axis.
    features = extract_histogram(image).reshape(1, -1)
    return model.predict(features)


for testImagePath in (
    './data/test/cat.1016.jpg',
    './data/test/cat.1024.jpg',
    './data/test/dog.1006.jpg',
    './data/test/dog.1033.jpg',
):
    print(predict_image(testImagePath))
# 数据集下载地址 (dataset download): Dogs vs. Cats | Kaggle — https://www.kaggle.com/c/dogs-vs-cats