Linda 正在尝试使用 SIFT 算法和 k-means/Ward 聚类来构建图像检索系统和 SVM 分类器,希望能够识别猫的图片。她已经将猫的图片和非猫的图片收集起来,并使用 SIFT 算法提取了图片的特征。她使用 k-means 将这些特征聚类成 3 个中心,并使用这些中心作为图片的描述符。然后,她将这些描述符转换为稀疏文件,并使用这些文件训练 SVM 分类器。
2、解决方案
Andy 指出了 Linda 在方法中的一些错误:
- **数据分割:**在进行任何机器学习任务之前,首先应该将数据分割成训练集和测试集。这样才能评估分类器的性能。
- **特征选择:**对于词袋模型,不应该使用聚类中心作为图片的描述符。应该统计每一张图片中各个聚类中心出现的次数,并以此构建直方图作为描述符。
- **聚类中心的数量:**聚类中心的数量太少。3 个聚类中心无法提供足够的信息来区分猫和非猫的图片。应该尝试使用 100-500 个聚类中心。
- **样本数量:**样本数量太少。50 张图片对于训练 SVM 分类器来说太少了。应该尝试使用至少 50-100 张图片。
- **特征选择:**对于猫的图片,HOG 特征可能更适合作为描述符。
Andy 还建议使用 np.bincount + 除以总和的方法来构建直方图描述符。
代码示例:
import numpy as np
import vlfeat_module
from sklearn.svm import NuSVC
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
# 数据集
cat_pictures = ['cat1.jpg', 'cat2.jpg', 'cat3.jpg']
noncat_pictures = ['noncat1.jpg', 'noncat2.jpg', 'noncat3.jpg']
# 特征提取
def get_features(datas):
list = []
for data in datas:
loc, des = vlfeat_module.vlf_create_desc(data, 'tmp.sift')
list.append(hstack((loc, des)))
desc = numpy.vstack(list)
return desc
# 聚类
def get_centers(desc, k):
center, _ = kmeans(desc, k)
return center
# 构建直方图描述符
def get_histograms(centers, des):
histograms = []
for img_des in des:
hist = np.bincount(kmeans(img_des, centers)[1])
histograms.append(hist / np.sum(hist))
return histograms
# 训练 SVM 分类器
def train_svm(histograms, labels):
X_train, X_test, y_train, y_test = train_test_split(histograms, labels, test_size=0.2)
clf = NuSVC(gamma=0.07, verbose=True)
clf.fit(X_train, y_train)
return clf
# 评估 SVM 分类器
def evaluate_svm(clf, X_test, y_test):
pred = clf.predict(X_test)
accuracy = accuracy_score(y_test, pred)
return accuracy
# 加载数据
cat_des = get_features(cat_pictures)
noncat_des = get_features(noncat_pictures)
# 聚类
cat_centers = get_centers(cat_des, 100)
noncat_centers = get_centers(noncat_des, 100)
# 构建直方图描述符
cat_histograms = get_histograms(cat_centers, cat_des)
noncat_histograms = get_histograms(noncat_centers, noncat_des)
# 合并数据集
histograms = cat_histograms + noncat_histograms
labels = [1] * len(cat_histograms) + [0] * len(noncat_histograms)
# 训练 SVM 分类器
clf = train_svm(histograms, labels)
# 评估 SVM 分类器
accuracy = evaluate_svm(clf, X_test, y_test)
print('准确率:', accuracy)