全连接网络鸢尾花数据集

import tensorflow as tf
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# 加载数据
TRAIN_URL = "http://download.tensorflow.org/data/iris_training.csv"
train_path = tf.keras.utils.get_file(TRAIN_URL.split('/')[-1], TRAIN_URL)
TEST_URL = "http://download.tensorflow.org/data/iris_test.csv"
test_path = tf.keras.utils.get_file(TEST_URL.split('/')[-1], TEST_URL)

df_iris_train = pd.read_csv(train_path, header=0)
df_iris_test = pd.read_csv(test_path, header=0)

iris_train = np.array(df_iris_train)
iris_test = np.array(df_iris_test)

# print(iris_train.shape, iris_test.shape)

x_train = iris_train[:, 0:4]
y_train = iris_train[:, 4]
x_test = iris_test[:, 0:4]
y_test = iris_test[:, 4]

x_train = x_train - np.mean(x_train, axis=0)
x_test = x_test - np.mean(x_test, axis=0)

X_train = tf.cast(x_train, tf.float32)
Y_train = tf.one_hot(tf.constant(y_train, dtype=tf.int32), 3)
X_test = tf.cast(x_test, tf.float32)
Y_test = tf.one_hot(tf.constant(y_test, dtype=tf.int32), 3)

learn_rate = 0.5
iter = 50
display_step = 10

np.random.seed(612)
W1 = tf.Variable(np.random.randn(4, 16), dtype=tf.float32)
B1 = tf.Variable(np.zeros([16]), dtype=tf.float32)
W2 = tf.Variable(np.random.randn(16, 3), dtype=tf.float32)
B2 = tf.Variable(np.zeros([3]), dtype=tf.float32)

acc_train = []
acc_test = []
cce_train = []
cce_test = []

for i in range(0, iter + 1):
    with tf.GradientTape() as tape:
        Hidden_train = tf.nn.relu(tf.matmul(X_train, W1) + B1)
        PRED_train = tf.nn.softmax(tf.matmul(Hidden_train, W2) + B2)
        Loss_train = tf.reduce_mean(tf.keras.losses.categorical_crossentropy(y_true=Y_train, y_pred=PRED_train))

    Hidden_test = tf.nn.relu(tf.matmul(X_test, W1) + B1)
    PRED_test = tf.nn.softmax(tf.matmul(Hidden_test, W2) + B2)
    Loss_test = tf.reduce_mean(tf.keras.losses.categorical_crossentropy(y_true=Y_test, y_pred=PRED_test))

    accuracy_train = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(PRED_train.numpy(), axis=1), y_train), tf.float32))
    accuracy_test = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(PRED_test.numpy(), axis=1), y_test), tf.float32))

    acc_train.append(accuracy_train)
    acc_test.append(accuracy_test)
    cce_train.append(Loss_train)
    cce_test.append(Loss_test)

    grads = tape.gradient(Loss_train, [W1, B1, W2, B2])
    W1.assign_sub(learn_rate * grads[0])
    B1.assign_sub(learn_rate * grads[1])
    W2.assign_sub(learn_rate * grads[2])
    B2.assign_sub(learn_rate * grads[3])

    if i % display_step == 0:
        print("%i :%f :%f :%f :%f" % (i, accuracy_train, Loss_train, accuracy_test, Loss_test))


plt.figure()

plt.subplot(211)
plt.plot(cce_train)
plt.plot(cce_test)

plt.subplot(212)
plt.plot(acc_train)
plt.plot(acc_test)

plt.show()

  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

织蛾

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值