Exercise 5.5: Implement the standard BP algorithm and the accumulated BP algorithm, train a single-hidden-layer network with each on the watermelon dataset 3.0, and compare the two.
The main idea follows the BP section of Zhou Zhihua's Machine Learning: both the standard BP algorithm and the accumulated BP algorithm described there are implemented. For the watermelon dataset 3.0, the categorical attributes have already been mapped to discrete numeric values.
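For reference, these are the update rules from the book's derivation that the code below implements, with learning rate η, hidden-unit outputs b_h, network outputs ŷ_j, hidden thresholds γ_h and output thresholds θ_j (in the code: r, b, yo, t0 and t1):

\[
g_j = \hat{y}_j (1 - \hat{y}_j)(y_j - \hat{y}_j), \qquad
e_h = b_h (1 - b_h) \sum_j w_{hj}\, g_j,
\]
\[
\Delta w_{hj} = \eta g_j b_h, \quad
\Delta \theta_j = -\eta g_j, \quad
\Delta v_{ih} = \eta e_h x_i, \quad
\Delta \gamma_h = -\eta e_h.
\]

Standard BP applies these updates after every single example, while accumulated BP sums them over the entire training set before updating, which is what the matrix form below exploits.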
To solve the XOR problem instead, just uncomment the following block inside the script:
x = np.mat( '1,1,2,2;\
1,2,1,2\
').T
x = np.array(x)
y = np.mat('0,1,1,0')
y = np.array(y).T
I wrote an earlier version (see the beginner version here) that updated every parameter one by one inside for loops. This version relies on numpy's matrix operations instead, which shortens the code considerably and also runs noticeably faster than before.
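To illustrate the difference, here is a minimal sketch (with made-up shapes, not code from either version) of the same input-to-hidden update done element by element versus as one matrix product:

import numpy as np

# hypothetical shapes: 17 samples, 8 inputs, 5 hidden units
x = np.random.random((17, 8))  # inputs
e = np.random.random((17, 5))  # hidden-layer gradient terms
r = 0.1                        # learning rate

# loop version: one scalar update per weight per sample
v = np.zeros((8, 5))
for k in range(x.shape[0]):
    for i in range(v.shape[0]):
        for h in range(v.shape[1]):
            v[i, h] += r * e[k, h] * x[k, i]

# vectorized version: the same accumulated update in one line
v2 = r * x.T.dot(e)
assert np.allclose(v, v2)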
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
############################
#File Name: bp-watermelon3.py
#Author: No One
#E-mail: 1130395634@qq.com
#Created Time: 2017-02-23 13:30:35
############################
import numpy as np
from sys import argv
x = np.mat( '2,3,3,2,1,2,3,3,3,2,1,1,2,1,3,1,2;\
1,1,1,1,1,2,2,2,2,3,3,1,2,2,2,1,1;\
2,3,2,3,2,2,2,2,3,1,1,2,2,3,2,2,3;\
3,3,3,3,3,3,2,3,2,3,1,1,2,2,3,1,2;\
1,1,1,1,1,2,2,2,2,3,3,3,1,1,2,3,2;\
1,1,1,1,1,2,2,1,1,2,1,2,1,1,2,1,1;\
0.697,0.774,0.634,0.668,0.556,0.403,0.481,0.437,0.666,0.243,0.245,0.343,0.639,0.657,0.360,0.593,0.719;\
0.460,0.376,0.264,0.318,0.215,0.237,0.149,0.211,0.091,0.267,0.057,0.099,0.161,0.198,0.370,0.042,0.103\
').T
x = np.array(x)
y = np.mat('1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0')
y = np.array(y).T
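# x: one row per sample, 8 columns: the 6 categorical attributes
# (already encoded as small integers) plus density and sugar content;
# y: 1 = good melon, 0 = bad melon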
'''
x = np.mat( '1,1,2,2;\
1,2,1,2\
').T
x = np.array(x)
y = np.mat('0,1,1,0')
y = np.array(y).T
'''
xrow, xcol = x.shape
yrow, ycol = y.shape
np.random.seed(0)  # fix the seed so runs are reproducible
print('x: \n', x)
print('y: \n', y)
def sigmoid(x):
    # logistic activation; its derivative sigmoid(x) * (1 - sigmoid(x))
    # is what the gradient terms g and e below rely on
    return 1.0 / (1.0 + np.exp(-x))

def printParam(v, w, t0, t1):
    print('v: ', v)
    print('w: ', w)
    print('t0: ', t0)
    print('t1: ', t1)
def bpa(x, y, n_hidden_layer, r, error, n_max_train):
    print()
    print('accumulated bp algorithm')
    print('------------------------------------')
    print('init param')
    xrow, xcol = x.shape
    yrow, ycol = y.shape
    v = np.random.random((xcol, n_hidden_layer))  # input-to-hidden weights
    w = np.random.random((n_hidden_layer, ycol))  # hidden-to-output weights
    t0 = np.random.random((1, n_hidden_layer))    # hidden-layer thresholds
    t1 = np.random.random((1, ycol))              # output-layer thresholds
    print('---------- train begins ----------')
    n_train = 0
    yo = 0
    loss = 0
    while True:
        # forward pass over the whole training set
        b = sigmoid(x.dot(v) - t0)
        yo = sigmoid(b.dot(w) - t1)
        loss = sum((yo - y)**2) / xrow
        if loss < error or n_train > n_max_train:
            break
        n_train += 1
        # gradient terms; e must be computed with w before it is updated
        g = yo * (1 - yo) * (y - yo)
        e = b * (1 - b) * g.dot(w.T)
        # accumulate the updates over all samples at once
        w += r * b.T.dot(g)
        t1 -= r * g.sum(axis=0)
        v += r * x.T.dot(e)
        t0 -= r * e.sum(axis=0)
        if n_train % 10000 == 0:
            print('train count: ', n_train)
            print(np.hstack((y, yo)))
    print()
    print('---------- train ends ----------')
    print('train count = ', n_train)
    print('---------- learned param: ----------')
    printParam(v, w, t0, t1)
    print('---------- result: ----------')
    print(np.hstack((y, yo)))
    print('loss: ', loss)
def bps(x, y, n_hidden_layer, r, error, n_max_train):
    print()
    print('standard bp algorithm')
    print('------------------------------------')
    print('init param')
    xrow, xcol = x.shape
    yrow, ycol = y.shape
    v = np.random.random((xcol, n_hidden_layer))
    w = np.random.random((n_hidden_layer, ycol))
    t0 = np.random.random((1, n_hidden_layer))
    t1 = np.random.random((1, ycol))
    print('---------- train begins ----------')
    n_train = 0
    tag = 0
    yo = 0
    loss = 0
    while True:
        for k in range(len(x)):
            # forward pass over the whole set, so the stopping
            # criterion sees the loss on all samples
            b = sigmoid(x.dot(v) - t0)
            yo = sigmoid(b.dot(w) - t1)
            loss = sum((yo - y)**2) / xrow
            if loss < error or n_train > n_max_train:
                tag = 1
                break
            n_train += 1
            # per-sample gradient terms; e uses w before the update
            b = b[k].reshape(1, -1)
            g = yo[k] * (1 - yo[k]) * (y[k] - yo[k])
            g = g.reshape(1, -1)
            e = b * (1 - b) * g.dot(w.T)
            # update the parameters with this sample only
            w += r * b.T.dot(g)
            t1 -= r * g
            v += r * x[k].reshape(1, -1).T.dot(e)
            t0 -= r * e
            if n_train % 10000 == 0:
                print('train count: ', n_train)
                print(np.hstack((y, yo)))
        if tag:
            break
    print()
    print('---------- train ends ----------')
    print('train count = ', n_train)
    print('---------- learned param: ----------')
    printParam(v, w, t0, t1)
    print('---------- result: ----------')
    print(np.hstack((y, yo)))
    print('loss: ', loss)
r = 0.1                # learning rate
error = 0.001          # stop once the mean squared error falls below this
n_max_train = 1000000  # hard cap on the number of updates
n_hidden_layer = 5     # number of hidden units
n = int(argv[1])
if n == 1:
    bpa(x, y, n_hidden_layer, r, error, n_max_train)
elif n == 2:
    bps(x, y, n_hidden_layer, r, error, n_max_train)
else:
    print('invalid command-line argument')
Command-line usage: python bp-watermelon3.py 1 # 1 runs the accumulated BP algorithm, 2 the standard BP algorithm
Sample output:
---------- train ends ----------
train count = 10472
---------- learned param: ----------
v: [[ 0.73242941 3.65170127 0.59713105 0.53589607 4.26680198]
[-0.47920127 0.38050143 0.88684761 0.96043754 -5.04922845]
[-3.53478658 -2.43632002 0.5617708 0.91791984 -3.99160595]
[ 2.72776748 -3.03747142 0.82596831 0.7629904 3.58719733]
[-0.49817982 0.11022257 0.45451013 0.77793538 -2.06655661]
[ 1.31898792 2.91731759 0.94006976 0.51654301 4.99262637]
[-2.87599092 -1.20602034 0.4544875 0.56614856 -2.14762434]
[ 3.31012315 2.37538414 0.61649596 0.94252131 0.76600351]]
w: [[ -7.57093041]
[ -4.87555553]
[ -0.60132992]
[ -1.24255911]
[ 11.75120165]]
t0: [[ 2.86694039 1.63790548 0.13229702 0.31939894 1.83144759]]
t1: [[ 1.86987471]]
---------- result: ----------
[[ 1.00000000e+00 9.93190538e-01]
[ 1.00000000e+00 9.99558269e-01]
[ 1.00000000e+00 9.73273387e-01]
[ 1.00000000e+00 9.98817906e-01]
[ 1.00000000e+00 9.95520603e-01]
[ 1.00000000e+00 9.58776391e-01]
[ 1.00000000e+00 9.26738291e-01]
[ 1.00000000e+00 9.78479082e-01]
[ 0.00000000e+00 5.84289232e-03]
[ 0.00000000e+00 6.31392712e-03]
[ 0.00000000e+00 8.31158755e-04]
[ 0.00000000e+00 1.51786116e-03]
[ 0.00000000e+00 2.72394938e-02]
[ 0.00000000e+00 2.37542259e-02]
[ 0.00000000e+00 7.79277689e-02]
[ 0.00000000e+00 1.85295127e-02]
[ 0.00000000e+00 2.97535714e-02]]
loss: [ 0.00099981]
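To finish the comparison the exercise asks for, one simple approach is to time both runs from the same initial weights. This is a sketch only: it assumes bpa, bps, x, y and the hyper-parameters above are in scope, and note that bpa counts whole-set epochs while bps counts per-sample updates, so wall-clock time is the fairer yardstick.

import time

for name, algo in (('accumulated', bpa), ('standard', bps)):
    np.random.seed(0)  # identical initial weights for a fair comparison
    start = time.time()
    algo(x, y, n_hidden_layer, r, error, n_max_train)
    print(name, 'bp finished in %.2f s' % (time.time() - start))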