神经网络实现鸢尾花分类
数据集介绍:共有数组150组,每组包含花萼长、花萼宽、花瓣长、花瓣宽4个输入特征,
输出三个类别:狗尾草鸢尾,杂色鸢尾,弗吉尼亚鸢尾,分别用0,1,2表示
第一步:准备数据
1.数据集读入
从sklearn包datasets读入数据集:
from sklearn.datasets import datasets
x_data = datasets.load_iris().data #读入输入特征
y_data = datasets.load_iris().target #读入标签
2.数据集乱序(实现代码如下)
np.random.seed(116)#使用相同的随机种子,可以让打乱顺序后输入特征和标签仍然一一对应
np.random.shuffle(x_data)
np.random.seed(116)
np.random.shuffle(x_data)
np.random.seed(116)
3.生成训练集和测试集(即x_train/y_train,x_test/y_test),注意训练集和测试集应该是永不相见的
x_train = x_data[:30]#一共150个数据,前120个拿出来做训练集,后30个做测试集
y_train = y_train[:30]#训练集和数据集永不相见,能够公正评判神经网络的效果
x_test = x_data[-30:]
x_test = y_data[-30:]
4.配成特征-标签对,每次读取一小撮(batch)
train_db = tf.data.Dataset.from_tensor_slices((x_train,y_train)).batch(32)
#把训练集的输入特征和标签配对打包,每32组输入特征标签对,打包为一个batch,会以batch喂入神经网络
train_db = tf.data.Dataset.from_tensor_slices((x_test,y_test)).batch(32)
第二步:搭建网络
定义神经网络中的所有可训练参数
w1 = tf.Variable(tf.random.truncated_normal([4,3],stddev=0.1,seed=1))
#输入特征是4个,输出分类数是3个
b1 = tf.Variable(tf.random.truncated_normal([3],stddev=0.1,seed=1))
#b1是偏差量,必须与输出的维度保持一致
第三步:参数优化
利用嵌套循环,在with结构中求得损失函数loss对每个可训练参数的偏导,更新这些可训练参数
for epoch in range(epoch)#第一层for循环针对整个数据集进行循环
for step,(x_train,y_train) in enumerate(train_db):#第二层for循环针对batch
with tf.GradientTape() as tape:#记录梯度信息
前向传播过程计算y
计算总loss
grads=tape.gradient(loss,[w1,b1])
w1.assign_sub(lr*grads[0])#参数自更新
b1.assign_sub(lr*grads[1])
prant("Epoch{},loss:{}".format(epoch,loss_all/4))#打印这一轮epoch后的loss损失函数值,因为训练集有120个数据,每个step喂入的一个batch是32个,所以需要batch级别循环4次,loss_all/4求得每次step迭代的平均loss
第四步:测试效果
为了直观查看效果,可每遍历一次数据集,显示当前准确率
for x_test,y_test in test_db:
y = tf.matmul(h,w)+b#前向传播计算y
y = tf.nn.softmax(y)#使y的值符合概率分布
pred = tf.argmax(y,axis=1)#返回y中最大值的索引,也就是y的分类
pred = tf.cast(pred,dtype=y_test.dtype)#调整数据类型与标签一致
correct = tf.cast(tf.equal(pred,y_test),dtype=tf.int32)#如果预测值和标签相等
correct = tf.reduce_sum(correct)#correct自加一
total_correct+=int(correct)
total_number+=x_text.shape[0]
acc = total_correct/total_number
print("test_acc:",acc)#打印准确率
第五步:acc/loss可视化
为了直观查看效果,可以画出准确率acc和损失函数loss的变化曲线图
plt.title('Acc Curve')#图片标题
plt.xlable('Epoch')#x轴名称
plt.ylable('Acc')#y轴名称
plt.plot(test_acc,label="$Accuracy$")#逐点画出test_acc值并连线
plt.legend()
plt.show()
#用同样的方法也可以画出loss曲线