tensorflow hello world mnist

import tensorflow as tf
# tensorflow==2.1.0

# 载入并准备好 MNIST 数据集。将样本从整数转换为浮点数:
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

# 0~9
class_nums = 10
epochs = 5

# 将模型的各层堆叠起来,以搭建 tf.keras.Sequential 模型。为训练选择优化器和损失函数:
model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(class_nums, activation='softmax')
])

# 训练并验证模型:
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
model.fit(x_train, y_train, epochs=epochs)
model.evaluate(x_test, y_test, verbose=2)
# 现在,这个照片分类器的准确度已经达到 98%。想要了解更多,请阅读 TensorFlow 教程。

# 2020-07-07 guangjinzheng tensorflow course

 

import tensorflow as tf
# tensorflow==2.1.0
import matplotlib.pyplot as plt
from matplotlib.pyplot import MultipleLocator

# 载入并准备好 MNIST 数据集。将样本从整数转换为浮点数:
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

# 0~9
class_nums = 10
learn_rate = 1e-4
epochs = 30   # 5;20
batch_sizes = 32

# 将模型的各层堆叠起来,以搭建 tf.keras.Sequential 模型。为训练选择优化器和损失函数:
model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(class_nums, activation='softmax')
])

# 训练并验证模型:
# model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
model.compile(optimizer=tf.keras.optimizers.Adam(lr=learn_rate), 
              loss='sparse_categorical_crossentropy', metrics=['accuracy'])
history = model.fit(x_train, y_train, epochs=epochs, batch_size=batch_sizes, validation_data=(x_test, y_test))
model.evaluate(x_test, y_test, verbose=2)
print('{}'.format(history.history))

# 图可视化
def pltshow(loss, val_loss, accuracy, val_accuracy):
    epochs_range = range(epochs)

    plt.figure(figsize=(8, 8))
    plt.subplot(1, 2, 1)
    plt.plot(epochs_range, accuracy, label='Training Accuracy')
    plt.plot(epochs_range, val_accuracy, label='Validation Accuracy')
    plt.legend(loc='lower right')
    plt.title('Training and Validation Accuracy')
    plt.xlabel('Epoch', fontsize=14)
    plt.ylabel('Accuracy', fontsize=14)
    ax = plt.gca()
    # ax.xaxis.set_major_locator(MultipleLocator(5))
    ax.yaxis.set_major_locator(MultipleLocator(0.05))
    plt.xlim(0, epochs)
    plt.ylim(0, 1)
    plt.grid()

    plt.subplot(1, 2, 2)
    plt.plot(epochs_range, loss, label='Training Loss')
    plt.plot(epochs_range, val_loss, label='Validation Loss')
    plt.legend(loc='upper right')
    plt.title('Training and Validation Loss')
    plt.xlabel('Epoch', fontsize=14)
    plt.ylabel('Loss', fontsize=14)
    ax = plt.gca()
    # ax.xaxis.set_major_locator(MultipleLocator(5))
    ax.yaxis.set_major_locator(MultipleLocator(0.1))
    plt.xlim(0, epochs)
    plt.ylim(0, 1)
    plt.grid()
    plt.show()

# 训练可视化
def history_show(history):
    loss = history['loss']
    val_loss = history['val_loss']
    accuracy = history['accuracy']
    val_accuracy = history['val_accuracy']
    pltshow(loss, val_loss, accuracy, val_accuracy)

history_show(history.history)

# 2020-09-27 guangjinzheng my mnist

 

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值