手写BP神经网络作用于IRIS数据集
import numpy as np
from sklearn import datasets
from sklearn.model_selection import train_test_split
############################################################################
#####激活函数用的sigmoid()适合二分问题,推荐使用ReLU函数
#####为了方便损失函数使用的是直接相减,
#####使用交叉熵函数更适合
############################################################################
def ReLu(x):
# ReLu 函数
x = (np.abs(x) + x) / 2.0
return x
def dReLu(x):
# ReLu 导数
x[x <= 0] = 0
x[x > 0] = 1
return x
iris=datasets.load_iris()
x=iris.data
y=iris.target
x_train,x_test,y_train,y_test=train_test_split(x,y,test_size=0.3)
#print(x_train.shape)
#定义激活函数f(x),和他的导数
def f(x):
return 1/(1+np.exp(-x))
def df(x):
return x*(1-x)
##def change2d(x):
## a=np.array([])
## for i in range(len(x)):
## i1=np.atleast_1d(x[i])
## np.append(a,i1)
## return a
#预测结果计算
def predict(x):
L1=f(np.dot(x,w0))
L2=ReLu(np.dot(L1,w1))
L2=np.around(L2,0)#分类问题,四舍五入
L2=L2.flatten().astype(int)#转化为数组
return L2
def calc_acc(y_pre,y_test):
#求准确率
return sum(y_pre==y_test)/y_test.shape[0]
#####################
np.random.seed(1)#使前后所使用的随机数一样
#####################
w0=np.random.random((4,5))*2-1
w1=np.random.random((5,1))*2-1
#print(x_train.shape)
#print(w0.shape)
def train(x_train,y_train,number=10000,lr=0.01):
global w0,w1#这里一定要长点心,全局变量
y_train.resize((105, 1))
for i in range(number+1):
l0=x_train
l1=f(np.dot(l0,w0))
l2=ReLu(np.dot(l1,w1))#正向
El2=y_train-l2#这里简化了损失函数,多分问题使用交叉熵函数会更好
if i%500==0:#梯度下降每1000次打印观察一下结果
y_pred=predict(x_test)
#print("E"+str(np.mean(np.abs(El2))))
print("准确率是:%0.3f"%calc_acc(y_test,y_pred))
print(y_pred,y_test)
#反向传播
l2_r=El2*dReLu(l2)#偏导值105*1
El1=l2_r.dot(w1.T)
l1_r=El1*df(l1)
#print(l2_r.shape,l1.shape,w0.shape)
#更新
w1+=lr*l1.T.dot(l2_r)
w0+=lr*x_train.T.dot(l1_r)
train(x_train,y_train,3000)
##l0=x_train
##print(l0.shape)
##l1=f(np.dot(l0,w0))
##print(l1.shape)
##l2=f(np.dot(l1,w1))
##print(l2.shape)
##El2=y_train-l2
#y_train.resize((105, 1))
#print(y_train*(1-y_train))求导形状不会改变
##print(y_train)
##
###print(y_train)
###c=change2d(y_train)
###print(c)
##print(El2.shape)
看到这篇文章能给我点个赞嘛,QAQ