# -*- coding: utf-8 -*-
import fileinput
import numpy as np
def install(fileName):
    """Load a whitespace-separated data file into feature and label matrices.

    Each non-blank line holds the feature values followed by the label in the
    last column; blank lines are ignored.

    Args:
        fileName: path of the data file to read.

    Returns:
        A tuple ``(X, y)``: ``X`` is an ``np.matrix`` of floats with one row
        per data line (every column except the last), and ``y`` is an
        ``np.matrix`` column vector holding the integer labels (last column).
    """
    xSet = []
    ySet = []
    # ``with`` guarantees the file is closed even if a line fails to parse.
    with open(fileName) as data_file:
        for line in data_file:
            num_str = line.split()
            if not num_str:
                # Skip blank lines (e.g. a trailing empty line at end of file),
                # which would otherwise raise IndexError on num_str[-1].
                continue
            # Build a real list: on Python 3 ``map`` returns a lazy iterator,
            # which np.matrix cannot turn into a numeric row.
            xSet.append([float(v) for v in num_str[:-1]])
            ySet.append(int(num_str[-1]))
    return (np.matrix(xSet), np.matrix(ySet).T)
# Run PLA on the training set, visiting examples in their file order over and
# over, and report the total number of updates made before it halts.
xSet, ySet = install('hw1_15_train.dat')  # read the training set
# Prepend a constant 1 to every example so w[0] acts as the bias term.
xSet = np.concatenate((np.ones((xSet.shape[0], 1)), xSet), 1)
n_samples, n_features = xSet.shape
# Start from the zero weight vector, e.g. [[ 0. 0. 0. 0. 0.]] for 4 features
# plus bias; its length now follows the data instead of being hard-coded.
w = np.matrix(np.zeros(n_features))
count = 0  # total number of PLA updates
while True:
    correct_num = 0  # examples classified correctly during this pass
    for i in range(n_samples):  # one step per row of the matrix
        xn = xSet[i]
        yn = ySet[i]
        dot = np.dot(xn, w.T)
        # A score of exactly 0 is treated as a mistake for either label.
        # NOTE(review): if the assignment defines sign(0) = -1, an example with
        # yn == -1 and dot == 0 should count as correct instead -- confirm.
        if dot * yn <= 0:
            w += yn * xn
            count = count + 1
        else:
            correct_num = correct_num + 1
    # Halt after a full pass without mistakes. Comparing against the real
    # sample count (instead of a hard-coded 400) avoids looping forever on a
    # smaller file or stopping early on a larger one.
    if correct_num == n_samples:
        break
    print(correct_num)
# Same text as the Python 2 statement ``print count, w`` on both Py2 and Py3.
print("{} {}".format(count, w))
# Machine Learning Foundations, question 15
# 最新推荐文章于 2017-07-29 18:04:00 发布 (scraped page footer: "latest recommended article published 2017-07-29 18:04:00")