Tensorflow2.x.x最基础的神经网络(ANN)

Tensorflow2.x.x最基础的神经网络(ANN)

本章节主要使用Tensorflow2.x.x来搭建ANN神经网络。

ANN原理

这里直接放上小伙伴ANN的原理博客~

实现

使用ANN实现对MNIST数据集的分类。

import tensorflow as tf
# mnist数据集
from tensorflow.keras.datasets import mnist
# Adam优化器
from tensorflow.keras.optimizers import Adam
# 交叉熵损失函数,一般用于多分类
from tensorflow.keras.losses import CategoricalCrossentropy
# 模型和网络层
from tensorflow.keras import Model, layers

# 批次大小
BATCH_SIZE = 128
# 迭代次数
EPOCHS = 10
# 加载mnist的训练、测试数据集
train, test = mnist.load_data()
# 数据集的预处理
@tf.function
def preprocess(x, y):
    # 将x一维数据转为3维灰度图
    x = tf.reshape(x, [28, 28, 1])
    # 将x的范围由[0, 255]为[0, 1]
    x = tf.image.convert_image_dtype(x, tf.float32)
    # 将y数字标签进行独热编码
    y = tf.one_hot(y, 10)
    # 返回处理后的x和y
    return x, y

# 使用Dataset来减少内存的使用
train = tf.data.Dataset.from_tensor_slices(train)
# 对数据进行预处理并且给定BATCH_SIZE
train = train.map(preprocess).batch(BATCH_SIZE)

# test数据集同理
test = tf.data.Dataset.from_tensor_slices(test)
test = test.map(preprocess).batch(BATCH_SIZE)

# 搭建模型(只是其中的一种搭建方式而已)
x = layers.Input(shape=(28, 28, 1))                 # 输入为x, 大小为 28*28*1
y = layers.Flatten()(x)                             # 将高维数据扁平化
y = layers.Dense(1024, activation='relu')(y)        # 输出1024个神经元的全网络层
y = layers.Dense(512, activation='relu')(y)         # 输出512个神经元的全网络层
y = layers.Dense(256, activation='relu')(y)         # 输出256个神经元的全网络层
y = layers.Dense(128, activation='relu')(y)         # 输出128个神经元的全网络层
y = layers.Dense(64, activation='relu')(y)          # 输出64个神经元的全网络层
y = layers.Dense(32, activation='relu')(y)          # 输出32个神经元的全网络层
y = layers.Dense(10, activation='softmax')(y)       # 输出10个神经元的全网络层,最后一层使用了softmax进行激活,原因是我们希望提前[0, 1]之间的概率

# 创建模型
ann = Model(x, y)
# 编译模型,选择优化器、评估标准、损失函数
ann.compile(optimizer=Adam(), metrics=['acc'], loss=CategoricalCrossentropy())
# 进行模型训练
history = ann.fit(train, epochs=EPOCHS)
# 测试集的评估
score = ann.evaluate(test)
# 打印评估成绩
print('loss: {0}, acc: {1}'.format(score[0], score[1])) # loss: 0.11106619730560828, acc: 0.9769999980926514

# 绘制训练过程中每个epoch的loss和acc的折线图
import matplotlib.pyplot as plt
# history对象中有history字典, 字典中存储着“损失”和“评估标准”
epochs = range(EPOCHS)
fig = plt.figure(figsize=(15, 5), dpi=100)

ax1 = fig.add_subplot(1, 2, 1)
ax1.plot(epochs, history.history['loss'])
ax1.set_title('loss graph')
ax1.set_xlabel('epochs')
ax1.set_ylabel('loss val')

ax2 = fig.add_subplot(1, 2, 2)
ax2.plot(epochs, history.history['acc'])
ax2.set_title('acc graph')
ax2.set_xlabel('epochs')
ax2.set_ylabel('acc val')

fig.show()

输出结果如下:
在这里插入图片描述

展开阅读全文

没有更多推荐了,返回首页

©️2019 CSDN 皮肤主题: 大白 设计师: CSDN官方博客
应支付0元
点击重新获取
扫码支付

支付成功即可阅读