用NumPy对两个非线性可分的类进行分类:
程序:
#1) Import the libraries and the dataset generator
import numpy as np
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
# Work around garbled Chinese text in matplotlib figures
plt.rcParams['font.sans-serif']=['SimHei'] # font used so Chinese labels render correctly
plt.rcParams['axes.unicode_minus']=False # make the minus sign render correctly
# make_circles from scikit-learn generates the toy dataset
from sklearn.datasets import make_circles
# Random seed reused by the steps below so the run can be repeated
SEED=2017
#2) Generate the data: an inner and an outer circle with a little noise.
# The seed comes from SEED (rather than a second hard-coded literal) so the
# dataset and the train/validation split always stay in sync.
x, y = make_circles(n_samples=400, factor=.3, noise=.05, random_state=SEED)
# Boolean masks selecting the samples of each class
outer = y == 0
inner = y == 1
#3) Plot the raw samples to show the two classes (red = outer, blue = inner)
plt.title('2 circles')
plt.plot(x[outer, 0], x[outer, 1], "ro")
plt.plot(x[inner, 0], x[inner, 1], "bo")
plt.show()
#4) Shift the data by +1 on both axes so the circles are centred at (1, 1).
# This is a pure translation; the data is not rescaled.
x = x + 1
#5) Hold out 20% of the samples as a validation set to measure performance
x_train, x_val, y_train, y_val = train_test_split(
    x, y, test_size=0.2, random_state=SEED)
#6) A linear activation cannot separate these classes, so use the sigmoid
def sigmoid(x):
    """Return the logistic sigmoid 1 / (1 + exp(-x)) of *x*.

    Works element-wise on NumPy arrays as well as on scalars; every result
    lies in the interval [0, 1].

    For very negative inputs ``np.exp(-x)`` overflows to ``inf`` (NumPy may
    emit a RuntimeWarning); the returned value is then 0.0.
    """
    return 1 / (1 + np.exp(-x))
#7) Hyperparameters
n_hidden = 50  # number of neurons in the hidden layer
n_epochs = 1000
learning_rate = 1
#8) Initialise the weights and the bookkeeping variables
# Draw the initial weights from a generator seeded with SEED. The original
# code used the unseeded global NumPy RNG, so every run started from
# different weights and the printed results could not be reproduced.
rng = np.random.RandomState(SEED)
# Standard-normal weights (mean 0.0, default scale 1.0):
# input -> hidden has shape (n_features, n_hidden), hidden -> output (n_hidden,)
weights_hidden = rng.normal(0.0, size=(x_train.shape[1], n_hidden))
weights_output = rng.normal(0.0, size=n_hidden)
# Validation loss / accuracy history
hist_loss = []
hist_accuracy = []
#9) Train the single-hidden-layer network and print validation statistics
for i in range(n_epochs):
    # Gradient accumulators, reset at the start of every epoch
    del_w_hidden = np.zeros(weights_hidden.shape)
    del_w_output = np.zeros(weights_output.shape)
    # Visit the training samples one at a time (batch size 1); the weights
    # stay fixed inside this loop and only the gradients are accumulated.
    for x_, y_ in zip(x_train, y_train):
        # Forward pass
        hidden_input = np.dot(x_, weights_hidden)
        hidden_output = sigmoid(hidden_input)
        output = sigmoid(np.dot(hidden_output, weights_output))
        # Backward pass: output * (1 - output) is the sigmoid derivative
        error = y_ - output
        output_error = error * output * (1 - output)
        hidden_error = (np.dot(output_error, weights_output)
                        * hidden_output * (1 - hidden_output))
        del_w_output += output_error * hidden_output
        # x_[:, None] turns the sample into a column so the product
        # broadcasts to the shape of weights_hidden
        del_w_hidden += hidden_error * x_[:, None]
    # Update the weights with the gradient averaged over the training set
    weights_hidden += learning_rate * del_w_hidden / x_train.shape[0]
    weights_output += learning_rate * del_w_output / x_train.shape[0]
    # Report progress every 100 epochs
    if i % 100 == 0:
        hidden_output = sigmoid(np.dot(x_val, weights_hidden))
        out = sigmoid(np.dot(hidden_output, weights_output))
        loss = np.mean((out - y_val) ** 2)
        # The final prediction uses a threshold of 0.5
        predictions = out > 0.5
        accuracy = np.mean(predictions == y_val)
        # Record the statistics; the history lists were previously created
        # but never filled.
        hist_loss.append(loss)
        hist_accuracy.append(accuracy)
        print("epoch: ", '{:>4}'.format(i),
              "; validation loss: ", '{:>6}'.format(loss.round(4)),
              "; validation accuracy: ",
              '{:>6}'.format(accuracy.round(4)))
结果: