TensorFlow2.0程序实现鸢尾花数据集分类

该博客介绍了如何使用TensorFlow2.0进行鸢尾花数据集的分类。首先,通过sklearn导入数据并进行数据集的读入、乱序、分割。然后,构建神经网络模型,并定义训练参数。接着,通过嵌套循环迭代优化参数,同时显示损失(loss)。最后,计算并展示准确率(acc),以及可视化损失和准确率变化曲线。
摘要由CSDN通过智能技术生成

先回顾鸢尾花数据集,其提供了150组鸢尾花数据,每组包括鸢尾花的花萼长、花萼宽、花瓣长、花瓣宽 4个输入特征,同时还给出了这一组特征对应的鸢尾花类别。类别包括狗尾鸢尾、杂色鸢尾、弗吉尼亚鸢尾三类, 分别用数字0、1、2表示。使用此数据集代码如下:

from sklearn.datasets import load_iris
x_data = datasets.load_iris().data # 返回iris数据集所有输入特征
y_data = datasets.load_iris().target # 返回iris数据集所有标签

即从sklearn包中导出数据集,将输入特征赋值给x_data变量,将对应标签赋值给y_data变量。

程序实现

我们用神经网络实现鸢尾花分类仅需要三步:

(1)准备数据,包括数据集读入、数据集乱序,把训练集和测试集中的数据配成输入特征和标签对,生成train和test即永不相见的训练集和测试集;

(2)搭建网络,定义神经网络中的所有可训练参数;

(3)优化这些可训练的参数,利用嵌套循环在with结构中求得损失函数loss对每个可训练参数的偏导数,更改这些可训练参数,为了查看效果,程序中可以加入每遍历一次数据集显示当前准确率,还可以画出准确率acc和损失函数loss的变化曲线图。以上部分的完整代码与解析如下:

(1) 数据集读入:
from sklearn.datasets import datasets
x_data = datasets.load_iris().data # 返回iris数据集所有输入特征
y_data = datasets.load_iris().target # 返回iris数据集所有标签
(2) 数据集乱序:
np.random.seed(116) # 使用相同的seed,使输入特征/标签一一对应
np.random.shuffle(x_data)
np.random.seed(116)
np
  • 0
    点赞
  • 13
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值