PLA

15

import numpy as np

# 数据处理
def getData(file_name):
f = open(file_name)
data = []
line = [float(v) for v in line.split()]
line.insert(0,1.0)
line = tuple(line)
data.append(line)
return np.array(data)

def sign(y):
if y<=0 : return -1
else : return 1

data = getData("./data.txt")
print(data.shape)
train_time = 0
# np.random.shuffle(data) # 打乱数据, 打乱数据后结果不同, 第二题
w = np.array([0]*5,dtype=float)
while True:
isFinish = True
for index in range(data.shape[0]):
y = w.dot(data[index][:5])
if sign(y) != data[index][-1]:
w += data[index][:5] * data[index][-1]
# w += data[index][:5] * data[index][-1] * 0.5 # 第三题
isFinish = False
train_time += 1
print("update paramter:",train_time)
if isFinish == True:
break


18

pocket: 咋所有的错误点中进行参数更新， 下端代码中，每次迭代随机选取50个错误点，进行参数更新，如果更后的参数使得错误率降低，则更改最优参数为当前参数，否在则不进行修改。

import numpy as np

# 数据处理
def getData(file_name):
f = open(file_name)
data = []
line = [float(v) for v in line.split()]
line.insert(0,1.0)
line = tuple(line)
data.append(line)
return np.array(data)

# sign
def sign(yhat):
yhat = np.sign(yhat)
yhat[np.where(yhat==0)] = -1
return yhat

# 计算错误率
def err_rate(yhat,data):
return np.sum(yhat != data[:,-1])/yhat.shape[0]

def Pla_train(data, w, iternum):
yhat = sign(data[:,:5].dot(w.T))
errodle = err_rate(yhat,data)
best_w = w.copy()
for t in range(iternum):
index = np.where(yhat != data[:,-1])[0]
#print(index)
if not index.any():
break
# 随机挑选错误的进行更新, 打乱，挑选第一个进行更新
pos = index[np.random.permutation(len(index))[0]]
# 更新参数
w += data[pos][:5] * data[pos][-1]
# 新的yhat
yhat = sign(data[:,:5].dot(w.T))
errnew = err_rate(yhat,data)
if errnew < errodle:
best_w = w.copy()
return best_w, errnew

# 18
def train_18():
# 读入数据
data = getData("./train.txt")
# 初始化参数
w = np.array([0,0,0,0,0],dtype=np.float)
# 读入测试数据
data_test = getData("./test.txt")
err_test = 0
for i in range(2000):
w, err_r = Pla_train(data,w,50)
if i % 100 ==0:
print("当前训练错误率:",err_r)
# 输出测试错误率
yhat_test = sign(data_test[:,:5].dot(w.T))
err_test += err_rate(yhat_test,data_test)
print("平均测试集错误率：",err_test/2000)
# 最终测试错误率
yhat_last = sign(data_test[:,:5].dot(w.T))
print("最终测试集错误率：", err_rate(yhat_last,data_test))

if __name__ == "__main__":
train_18()


©️2019 CSDN 皮肤主题: 大白 设计师: CSDN官方博客