SVM 入门：对随机数据进行象限划分
下面的程序使用 SVM 对随机生成的二维数据进行象限划分。运行环境为 Python 3.7，需要安装 numpy、matplotlib 和 scikit-learn，不需要额外下载数据。
import numpy as np
import matplotlib.pyplot as plt
from sklearn import svm
import pylab as pl
from matplotlib.colors import ListedColormap
# Number of points drawn for each half-plane (x >= 0 and x <= 0); the full data
# set has 2 * DATA_LENGTH points. Configurable; more points generally give a
# better classification. The code below randomly generates points in all four
# quadrants.
DATA_LENGTH = 1000
_half_length = DATA_LENGTH // 2

# a: x and y in [0, 1). Negating y on the first half of the rows moves them to
# quadrant IV; the remaining rows stay in quadrant I.
a = np.random.random((DATA_LENGTH, 2))
a[:_half_length, 1] = -a[:_half_length, 1]

# b: x and y in (-1, 0]. Negating y on the first half of the rows moves them to
# quadrant II; the remaining rows stay in quadrant III.
b = -np.random.random((DATA_LENGTH, 2))
b[:_half_length, 1] = -b[:_half_length, 1]

# One label per row of np.r_[a, b], equal to the quadrant number:
#   a[:half] -> 4, a[half:] -> 1 (the np.ones default),
#   b[:half] -> 2, b[half:] -> 3.
label = np.ones(2 * DATA_LENGTH)
label[:_half_length] = 4
label[DATA_LENGTH:DATA_LENGTH + _half_length] = 2
label[DATA_LENGTH + _half_length:] = 3
# Stack both halves into one feature matrix whose rows line up with `label`.
ab = np.r_[a, b]
n_sample = 2 * DATA_LENGTH

# Shuffle features and labels with one shared permutation so that rows stay
# paired. The global seed makes the shuffle, and hence the split, repeatable.
np.random.seed(0)
order = np.random.permutation(n_sample)
ab, label = ab[order], label[order]

# The first 90% of the shuffled rows are used for training and the last 10% for
# testing. This completes generation of the experiment data.
ab_train, ab_test = np.split(ab, [int(.9 * n_sample)])
label_train, label_test = np.split(label, [int(.9 * n_sample)])
# Train a linear-kernel support vector classifier on the training split.
clf = svm.SVC(kernel='linear')
clf.fit(ab_train, label_train)
# SVC.score reports mean accuracy on the held-out split, so label it as such
# (it is not precision).
print('Accuracy:', clf.score(ab_test, label_test))
# Colours for the plot below, four entries each (one per quadrant class).
# The light shades fill the predicted regions; the bold shades mark the samples.
cmap_light, cmap_bold = (
    ListedColormap(palette)
    for palette in (
        ['#FFAAAA', '#AAFFAA', '#AAAAFF', '#AAAAAA'],
        ['#FF0000', '#00FF00', '#0000FF', '#FFFFFF'],
    )
)
def plot_estimator(estimator, x, y):
    """Fit ``estimator`` on (x, y), then draw its decision regions and the samples.

    Parameters
    ----------
    estimator : classifier exposing ``fit`` and ``predict`` (e.g. ``svm.SVC``).
        It is (re)fitted in place on ``x``/``y``.
    x : array of shape (n_samples, 2). The classifier is evaluated on a
        2-column grid, so exactly two features are expected.
    y : class labels for the rows of ``x``, used for fitting and for colouring.

    Uses the module-level ``cmap_light``/``cmap_bold`` colormaps, draws into
    figure number 0 and displays it with ``plt.show()``.
    """
    estimator.fit(x, y)

    # Evaluate the classifier on a 100x100 grid covering the data plus a 0.1 margin.
    x_min, x_max = x[:, 0].min() - .1, x[:, 0].max() + .1
    y_min, y_max = x[:, 1].min() - .1, x[:, 1].max() + .1
    xx, yy = np.meshgrid(np.linspace(x_min, x_max, 100),
                         np.linspace(y_min, y_max, 100))
    z = estimator.predict(np.c_[xx.ravel(), yy.ravel()]).reshape(xx.shape)

    # Use one colour range for both the regions and the points. Without it, each
    # call normalises over its own values, so a class missing from the grid
    # predictions (or from y) would shift the colours and the point colours would
    # no longer match the region colours.
    vmin = min(np.min(y), np.min(z))
    vmax = max(np.max(y), np.max(z))

    # plt (pyplot) provides the same functions as the discouraged pylab namespace.
    plt.figure(0)
    plt.pcolormesh(xx, yy, z, cmap=cmap_light, vmin=vmin, vmax=vmax)
    plt.scatter(x[:, 0], x[:, 1], c=y, cmap=cmap_bold, vmin=vmin, vmax=vmax)
    plt.axis('tight')
    plt.axis('off')
    plt.tight_layout()
    plt.show()
# Plot the classifier's decision regions over the training points
# (plot_estimator refits clf on the same data before drawing).
plot_estimator(estimator=clf, x=ab_train, y=label_train)
分类完成后的图片如下：
由于这个分类任务很简单（四个象限之间的分界线就是两条坐标轴，数据是线性可分的），测试准确率基本上每次都能达到 0.97（即 97%）以上；数据点越多，准确率越高。