import functools


@functools.lru_cache(maxsize=1)
def _trained_classifier():
    """Train (on first call only) and return the SVM digit classifier.

    The classifier is fitted on the first 50% of scikit-learn's bundled
    ``load_digits`` dataset (no shuffling), exactly as before. Fitting is
    deterministic, so caching the fitted model lets repeated ``predict``
    calls skip reloading the dataset and retraining.
    """
    from sklearn import datasets, svm
    from sklearn.model_selection import train_test_split

    digits = datasets.load_digits()
    # Flatten each 8x8 image into a 64-feature row.
    n_samples = len(digits.images)
    data = digits.images.reshape((n_samples, -1))
    # Only the first half is used for training; the held-out half is unused.
    X_train, _, y_train, _ = train_test_split(
        data, digits.target, test_size=0.5, shuffle=False
    )
    clf = svm.SVC(gamma=0.001)
    clf.fit(X_train, y_train)
    return clf


def predict(sample):
    """Predict the digit label of one 8x8 handwritten-digit image.

    Args:
        sample: The image's 64 pixel values, flattened row by row (a list or
            other array-like; an 8x8 nested sequence is flattened the same way).

    Returns:
        int: The predicted digit label.

    Raises:
        ValueError: If ``sample`` does not contain exactly 64 values.
    """
    import numpy as np

    pixels = np.asarray(sample)
    # Validate before touching the (expensive) classifier so bad input fails fast.
    if pixels.size != 64:
        raise ValueError(
            f"sample must contain 64 pixel values (an 8x8 image), got {pixels.size}"
        )
    clf = _trained_classifier()
    # SVC.predict expects a 2-D array of shape (n_samples, n_features).
    pred = clf.predict(pixels.reshape(1, 64))
    # The docstring contract is a plain integer label, not a length-1 array.
    return int(pred[0])
if __name__ == "__main__":
    # Demo: classify one handwritten digit given as an 8x8 grid of pixel
    # values; the rows are flattened into the 64-value list predict() takes.
    demo_rows = [
        [0., 0., 6., 14., 4., 0., 0., 0.],
        [0., 0., 11., 16., 10., 0., 0., 0.],
        [0., 0., 8., 14., 16., 2., 0., 0.],
        [0., 0., 1., 12., 12., 11., 0., 0.],
        [0., 0., 0., 0., 0., 11., 3., 0.],
        [0., 0., 0., 0., 0., 5., 11., 0.],
        [0., 0., 1., 4., 4., 7., 16., 2.],
        [0., 0., 7., 16., 16., 13., 11., 1.],
    ]
    demo_sample = [pixel for row in demo_rows for pixel in row]
    print(predict(demo_sample))