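# The data comes from the MNIST dataset, and the computation relies mainly on numpy's matrix type.
# The core of the implementation is solving for the parameters w and b.
# The derivation follows Li Hang's "Statistical Learning Methods": derive the optimisation objective, then update the parameters.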
import numpy as np
import pandas as pd
import time
def data_load(filename):
    '''
    Read a comma-separated MNIST-style file into features and binary labels.

    Every non-blank line is expected to look like ``label,pixel,pixel,...``
    with integer fields. A label >= 5 is mapped to +1, anything else to -1,
    and every pixel value is divided by 255.

    :param filename: path of the CSV file to read
    :return: dataArr,labelArr -- a list of per-sample feature lists and the
             matching list of +1 / -1 labels
    :raises ValueError: if a field on a non-blank line is not an integer
    '''
    print('start read file')
    dataArr,labelArr = [],[]
    with open(filename,'r') as f:
        # Iterate the file object directly so the whole file is never held
        # in memory as a list of lines.
        for line in f:
            stripped = line.strip()
            # A blank line would split to [''] and make int('') fail.
            if not stripped:
                continue
            fields = stripped.split(',')
            labelArr.append(1 if int(fields[0]) >= 5 else -1)
            dataArr.append([int(num)/255 for num in fields[1:]])
    print('End')
    return dataArr,labelArr
def perception(dataArr,labelArr,iter=50,lr=0.001):
    '''
    Train a perceptron with per-sample updates.

    In every pass the samples are visited in order; whenever
    ``y_i * (w . x_i + b) <= 0`` the parameters are updated with
    ``w <- w + lr * y_i * x_i`` and ``b <- b + lr * y_i``.

    :param dataArr: 2-D sequence of feature rows, one row per sample
    :param labelArr: sequence of labels, one per row (the update rule is
                     written for +1 / -1 labels)
    :param iter: number of full passes over the data (the name shadows the
                 builtin but is kept so keyword callers keep working)
    :param lr: learning rate used to scale every update
    :return: w,b -- the 1 x n weight row and the bias. After the first update
             both are numpy matrices (b is 1 x 1); if no update is ever made
             they stay the initial ``np.zeros((1, n))`` array and integer 0.
    '''
    # np.asmatrix is the function the old np.mat alias referred to; the alias
    # itself no longer exists in NumPy 2.x.
    dataMat = np.asmatrix(dataArr)
    labelMat = np.asmatrix(labelArr).T
    n = dataMat.shape[1]
    w = np.zeros((1,n))
    b = 0
    for k in range(iter):
        # Iterating a matrix yields 1 x n (resp. 1 x 1) matrix rows, the same
        # objects that integer indexing would return.
        for x_i, y_i in zip(dataMat, labelMat):
            # Misclassified (or exactly on the boundary): move the hyperplane
            # towards the sample.
            if y_i * (w * x_i.T + b) <= 0:
                w = w + lr * y_i * x_i
                b = b + lr * y_i
        print('iter:{}'.format(k))
    return w,b
def data_test(dataArr,labelArr,w,b):
    '''
    Compute the classification accuracy of a linear model, in percent.

    A sample counts as correct when ``y_i * (w . x_i + b) >= 0``, i.e. a
    sample lying exactly on the decision boundary is counted as correct.

    :param dataArr: 2-D sequence of feature rows, one row per sample
    :param labelArr: sequence of labels, one per row
    :param w: weights with n entries in total (e.g. a 1 x n array or matrix)
    :param b: bias, a scalar or a single-element array / matrix
    :return: percentage of correctly classified samples as a float
    :raises ZeroDivisionError: if dataArr contains no samples
    '''
    total = len(dataArr)
    if total == 0:
        # Accuracy over zero samples is undefined; fail with the same
        # exception type the plain division would produce.
        raise ZeroDivisionError('cannot compute accuracy on an empty dataset')
    # Plain ndarrays: flatten labels / weights so that one matrix-vector
    # product scores every sample at once.
    features = np.asarray(dataArr, dtype=float)
    labels = np.asarray(labelArr, dtype=float).ravel()
    weights = np.asarray(w, dtype=float).ravel()
    bias = np.asarray(b, dtype=float).item()
    margins = labels * (features @ weights + bias)
    rigSum = int(np.count_nonzero(margins >= 0))
    return rigSum/total*100
if __name__ == '__main__':
    # Load the training CSV, fit the perceptron, then report the accuracy
    # on that same data followed by the training time in seconds.
    features, targets = data_load('../dataset/mnist_train.csv')

    t_begin = time.time()
    weights, bias = perception(features, targets)
    # Only the training call is timed; evaluation happens afterwards.
    elapsed = time.time() - t_begin

    accuracy = data_test(features, targets, weights, bias)
    print(accuracy)
    print(elapsed)