本程序是将统计学习第二章的例题2.1用代码运算出来
#-*- coding:utf-8 -*-
import os
import sys
from numpy import *
reload(sys)
sys.setdefaultencoding('utf-8')
#首先将书上的训练集输入
def loadDataset():
dataset=[[3,3],[4,3],[1,1]]
labels=[1,1,-1]
return dataset,labels
#定义sign函数
def function_sign(x):
if x>=0:return 1
else:return -1
#算法主程序
def perceptron(dataSet,labels):
dataMat=mat(dataSet) #将数据集转换成矩阵用于计算
m,n=shape(dataMat)
w=mat(zeros((n,1)));b=0 #创建与示例维度相同的初始化w
while([function_sign(i) for i in array(dataMat * w+b)]) != labels: #当预测结果与实际结果不符时则循环更新我,w,b
for i in xrange(m):
if labels[i]*(dataMat[i]*w+b)<=0:
w += labels[i]*dataMat[i].T
b += labels[i]
return w,b
#输出运行结果w,b
dataSet,labels=loadDataset()
w,b=perceptron(dataSet,labels)
print w,'\n',b
结果如下:
[[ 1.]
[ 1.]]
-3
made by zcl at CUMT