# 神经网络实现鸢尾花分类
import tensorflow as tf
import numpy as np
# 读取数据集
from sklearn.datasets import load_iris
iris = load_iris()
x_data = iris['data']
y_data = iris['target']
# 使数据集乱序
np.random.seed(116)
np.random.shuffle(x_data) # 打乱数据
np.random.seed(116)
np.random.shuffle(y_data) # 打乱数据
tf.random.set_seed(116)
# 将数据集划分成训练集和测试集
x_train = x_data[:-30]
y_train = y_data[:-30]
assert len(x_train) == len(y_train)
x_test = x_data[-30:]
y_test = y_data[-30:]
assert len(x_test)==len(y_test)
# 转换x的数据类型,否则后面矩阵相乘时会因数据类型不一致报错
x_train = tf.cast(x_train, tf.float32)
x_test = tf.
1.神经网络实现鸢尾花分类(单层全连接——学习使用)
最新推荐文章于 2024-04-19 17:48:27 发布