1、导入数据
在使用随机过程(如随机数)的机器学习算法时,最好使用固定随机种子初始化随机数生成器。这样就可以重复运行相同的代码,并得到相同的结果。
实例中使用Numpy的函数loadtxt()函数加载Piman Indians数据集。Piman Indians数据集有8个输入维度和1个输出维度(最后一列)
from keras.models import Sequential
from keras.layers import Dense
import numpy as np
# 设定随机数种子
np.random.seed(7)
# 导入数据
dataset = np.loadtxt('pima-indians-diabetes.csv', delimiter=',')
# 分割输入x和输出Y
x = dataset[:, 0 : 8]
Y = dataset[:, 8]
2、定义模型
在Keras中,通常使用Dense类来定义完全连接的层。
在本例中,通过Sequential的add()函数将层添加到模型,并组合在一起。使用ReLU作为前两层的激活函数,使用sigmoid作为输出层的激活函数,第一个隐藏层有12个神经元,使用8个输出变量,第二隐藏层有8个神经元,最后输出层有1个神经元来预测数据结果(是否患有糖尿病),代码如下;
# 创建模型
model = Sequential()
model.add(Dense(12, input_dim=8, activation='relu'))
model.add(Dense(8, activation='relu'))
model.add(Dense(1, activation='sigmoid'))
3、编译模型
模型定义好后,需要对模型进行编译,编译模型是为了使模型能够有效地使用Keras封装的数值进行计算。在编译模型时,必须指定用于评估一组权重的损失函数(loss)、用于搜索网络不同权重的优化器(optimizer),以及希望在模型训练期间收集和报告的可选指标。
在这个例子中使用对数损失函数,作为模型的损失函数。在Keras中,对于二分类问题的对数损失函数定义为二进制交叉熵。使用有效的梯度下降算法Adam作为优化器,这是一个有效的默认值。
# 编译模型
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
4、训练模型
训练模型通过调用模型的fit()函数来实现。
# 训练模型
model.fit(x=x, y=Y, epochs=150, batch_size=10)
5、评估模型
可以使用模型的evaluation()函数来评估模型的准确度。在这个实例中使用训练集来评估模型的准确度,因此传给evaluation()函数的数据集与用于训练模型的数据集相同。、
# 评估模型
scores = model.evaluate(x=x, y=Y)
print('\n%s : %.2f%%' % (model.metrics_names[1], scores[1]*100))
6、汇总代码
到这里已经完成了基于Keras构建的第一个神经网络。下面给出完整的代码,并运行这个模型,代码而下;
from keras.models import Sequential
from keras.layers import Dense
import numpy as np
# 设定随机数种子
np.random.seed(7)
# 导入数据
dataset = np.loadtxt('pima-indians-diabetes.csv', delimiter=',')
# 分割输入x和输出Y
x = dataset[:, 0 : 8]
Y = dataset[:, 8]
# 创建模型
model = Sequential()
model.add(Dense(12, input_dim=8, activation='relu'))
model.add(Dense(8, activation='relu'))
model.add(Dense(1, activation='sigmoid'))
# 编译模型
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
# 训练模型
model.fit(x=x, y=Y, epochs=150, batch_size=10)
# 评估模型
scores = model.evaluate(x=x, y=Y)
print('\n%s : %.2f%%' % (model.metrics_names[1], scores[1]*100))
运行结果: