MNIST数据集简单测试

环境:TensorFlow 2.7.0, python 3.8

网络:全连接

特别注意:model.compile()参数设置时,metrics应该直接设置要监测的指标,我刚学比较菜,用了一个方法,没有报错,但是监测的准确率一直为零。

# -*- coding: utf-8 -*-

import tensorflow as tf
from tensorflow.keras import Sequential
from tensorflow.keras import layers

# Model definition
model = Sequential([
    layers.Dense(256, 'relu'),
    layers.Dense(128, 'relu'),
    layers.Dense(64, 'relu'),
    layers.Dense(32, 'relu'),
    layers.Dense(16, 'relu'),
    layers.Dense(10, 'Softmax'),
])

# The shape of the model input
model.build((None, 28 * 28))

# Model structure
model.summary()

# Model optimizer
model.compile(
    optimizer=tf.keras.optimizers.SGD(learning_rate=0.001),
    loss=tf.keras.losses.CategoricalCrossentropy(),
    metrics=['accuracy'],
)

# Load datasets: MNIST
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
print('x_train.shape={}, y_train.shape={}'.format(x_train.shape, y_train.shape))

# Construct training objects
train_db = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_db = train_db.shuffle(10000)
train_db = train_db.batch(32)  # the batch size is 128


def preprocessing(x, y):
    x = tf.cast(x, tf.float32) / 255.
    x = tf.reshape(x, [-1, 28 * 28])
    y = tf.one_hot(y, depth=10)

    return x, y


# Split line
def line(num):
    for i in range(num):
        print('-', end='')


# Preprocess the data
train_db = train_db.map(preprocessing)

# TRAIN
print('TRAIN')
model.fit(train_db, epochs=20)

# TEST
print('TEST')
test_db = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_db = test_db.shuffle(10000)
test_db = test_db.batch(32)
test_db = test_db.map(preprocessing)
result = model.evaluate(test_db)
print(result)

运行结果

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 dense (Dense)               (None, 256)               200960    
                                                                 
 dense_1 (Dense)             (None, 128)               32896     
                                                                 
 dense_2 (Dense)             (None, 64)                8256      
                                                                 
 dense_3 (Dense)             (None, 32)                2080      
                                                                 
 dense_4 (Dense)             (None, 16)                528       
                                                                 
 dense_5 (Dense)             (None, 10)                170       
                                                                 
=================================================================
Total params: 244,890
Trainable params: 244,890
Non-trainable params: 0
_________________________________________________________________
x_train.shape=(60000, 28, 28), y_train.shape=(60000,)
TRAIN
Epoch 1/20
1875/1875 [==========================] - 2s 716us/step - loss: 2.2427 - accuracy: 0.2231
Epoch 2/20
1875/1875 [==========================] - 1s 723us/step - loss: 1.8994 - accuracy: 0.4011
Epoch 3/20
1875/1875 [==========================] - 1s 717us/step - loss: 1.2413 - accuracy: 0.6643
Epoch 4/20
1875/1875 [==========================] - 1s 729us/step - loss: 0.7615 - accuracy: 0.7864
Epoch 5/20
1875/1875 [==========================] - 1s 726us/step - loss: 0.5668 - accuracy: 0.8424
Epoch 6/20
1875/1875 [==========================] - 1s 725us/step - loss: 0.4663 - accuracy: 0.8701
Epoch 7/20
1875/1875 [==========================] - 1s 711us/step - loss: 0.4045 - accuracy: 0.8862
Epoch 8/20
1875/1875 [==========================] - 1s 707us/step - loss: 0.3621 - accuracy: 0.8969
Epoch 9/20
1875/1875 [==========================] - 1s 730us/step - loss: 0.3298 - accuracy: 0.9059
Epoch 10/20
1875/1875 [==========================] - 1s 723us/step - loss: 0.3033 - accuracy: 0.9131
Epoch 11/20
1875/1875 [==========================] - 1s 721us/step - loss: 0.2812 - accuracy: 0.9190
Epoch 12/20
1875/1875 [==========================] - 1s 716us/step - loss: 0.2625 - accuracy: 0.9237
Epoch 13/20
1875/1875 [==========================] - 1s 715us/step - loss: 0.2455 - accuracy: 0.9286
Epoch 14/20
1875/1875 [==========================] - 1s 722us/step - loss: 0.2306 - accuracy: 0.9329
Epoch 15/20
1875/1875 [==========================] - 1s 717us/step - loss: 0.2177 - accuracy: 0.9364
Epoch 16/20
1875/1875 [==========================] - 1s 722us/step - loss: 0.2061 - accuracy: 0.9402
Epoch 17/20
1875/1875 [==========================] - 1s 729us/step - loss: 0.1958 - accuracy: 0.9431
Epoch 18/20
1875/1875 [==========================] - 1s 728us/step - loss: 0.1860 - accuracy: 0.9458
Epoch 19/20
1875/1875 [==========================] - 1s 729us/step - loss: 0.1775 - accuracy: 0.9479
Epoch 20/20
1875/1875 [==========================] - 1s 725us/step - loss: 0.1696 - accuracy: 0.9500

TEST
313/313 [============================] - 0s 511us/step - loss: 0.1727 - accuracy: 0.9477
[0.17273546755313873, 0.947700023651123]

Process finished with exit code 0

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值