鸢尾花_多层demo+详细注解

本文详细介绍了使用TensorFlow对鸢尾花数据集进行多层神经网络的训练过程,包括数据预处理、模型参数设置、训练和测试准确率计算以及损失函数的可视化。
摘要由CSDN通过智能技术生成

鸢尾花_多层demo+详细注解

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl
import tensorflow as tf
from sklearn import datasets


# 导入数据,分别为输入特征和标签
x_data = datasets.load_iris().data
y_data = datasets.load_iris().target

# 数据打乱
np.random.seed(116)
np.random.shuffle(x_data)
np.random.seed(116)
np.random.shuffle(y_data)
tf.random.set_seed(116)

# 2,数据预处理
# 2.1拆分样本特征和标签
x_train = x_data[:-30]
y_train = y_data[:-30]
print(y_train)
x_test = x_data
y_test = y_data
#print(x_test.shape, y_test.shape)
print(y_test)
#exit(0)
# 2.2数据归一化
# 由于样本的4个征征值尺度相同,不用进行归一化。

# 2.3数据中心化
# 需要按列中心化,所以指定axis=0
x_train = x_train - np.mean(x_train, axis=0)
x_test = x_test - np.mean(x_test, axis=0)
# print(x_train)
# print(x_test)

# 2.4类型转换
# 将训练集特征转为float32类型
X_train = tf.cast(x_train, tf.float32)
# 将训练集标签转为int32,再转为独热编码
Y_train = tf.one_hot(tf.constant(y_train, tf.int32), 3)
# print(Y_train[0:4,:])
# 将测试特征转为float32类型
X_test = tf.cast(x_test, tf.float32)
# 将测试标签转为int32,再转为独热编码
Y_test = tf.one_hot(tf.constant(y_test, tf.int32), 3)

# 3,设置超参数
# 学习率
learn_rate = 0.5
# 迭代次数
iter = 50
# 显示频率
display_step = 10
# 模型参数
np.random.seed(612)
W1 = tf.Variable(np.random.randn(4, 16), dtype=tf.float32)
B1 = tf.Variable(np.zeros([16]), dtype=tf.float32)
W2 = tf.Variable(np.random.randn(16, 3), dtype=tf.float32)
B2 = tf.Variable(np.zeros([3]), dtype=tf.float32)

# 4,训练模型
# 训练准确率
acc_train = []
# 测试准确率
acc_test = []
# 训练损失
cce_train = []
# 测试损失
cce_test = []

for i in range(0, iter + 1):
    with tf.GradientTape() as tape:
        # 训练集隐含层线性结果
        hidden_train = tf.matmul(X_train, W1) + B1
        # 训练集隐含层输出
        Hidden_train = tf.nn.relu(hidden_train)
        # 训练集输出层线性结果
        pred_train = tf.matmul(Hidden_train, W2) + B2
        # 训练集输出层输出
        PRED_train = tf.nn.softmax(pred_train)
        # 训练集交叉熵损失
        Loss_train = tf.reduce_mean(tf.keras.metrics.categorical_crossentropy(y_true=Y_train, y_pred=PRED_train))

        # 测试集隐含层线性输出
        Hidden_test = tf.nn.relu(tf.matmul(X_test, W1) + B1)
        # 测试集概率输出
        PRED_test = tf.nn.softmax(tf.matmul(Hidden_test, W2) + B2)
        # 测试集交叉熵损失
        Loss_test = tf.reduce_mean(tf.keras.metrics.categorical_crossentropy(y_true=Y_test, y_pred=PRED_test))

    # 添加交叉熵损失到列表
    cce_train.append(Loss_train)
    cce_test.append(Loss_test)
    # 计算准确率 tf.argmax(a,1)指在张量a的第一维度找到最大值的下标,并返回ndarray
    accuracy_train = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(PRED_train.numpy(), 1), y_train), tf.float32))
    accuracy_test = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(PRED_test.numpy(), 1), y_test), tf.float32))
    # 添加准确率到数组
    acc_train.append(accuracy_train)
    acc_test.append(accuracy_test)
    # 计算偏导数
    grads = tape.gradient(Loss_train, [W1, B1, W2, B2])
    # 更新参数
    W1.assign_sub(learn_rate * grads[0])
    B1.assign_sub(learn_rate * grads[1])
    W2.assign_sub(learn_rate * grads[2])
    B2.assign_sub(learn_rate * grads[3])
    if i % display_step == 0:
        print(i, ',训练集准确率:', accuracy_train.numpy(), ',训练集损失:', Loss_train.numpy(), ',测试集准确率:',
              accuracy_test.numpy(), ',测试集损失:', Loss_test.numpy())

# 5,可视化结果
plt.figure(figsize=(10, 3))
# 5.1绘制损失
plt.subplot(121)
# 绘制训练集损失
plt.plot(cce_train, color="blue", label="train")
# 绘制测试集损失
plt.plot(cce_test, color="red", label="test")
plt.xlabel("Iter")
plt.ylabel("cce")
plt.legend(["train", "test"])

# 5.2绘制准确率
plt.subplot(122)
# 绘制训练集准确率
plt.plot(acc_train, color="blue", label="train")
# 绘制测试集准确率
plt.plot(acc_test, color="red", label="test")
plt.xlabel("Iter")
plt.ylabel("acc")
plt.legend(["train", "test"])
plt.show()

  • 2
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

java_leaf

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值