TensorFlow实现多层神经网络进行fashion_mnist时装分类

import os
import numpy as np
import tensorflow as tf
from tensorflow.python import keras

# fashion_mnist数据集:70000张灰度图像,涵盖10个类别(T恤衫/上衣、裤子、套衫、裙子、外套、凉鞋、衬衫、运动鞋、包包)


class SingleNN(object):
    model = keras.models.Sequential([
        keras.layers.Flatten(input_shape=(28, 28)),
        keras.layers.Dense(128, activation=tf.nn.relu),
        keras.layers.Dense(128, activation=tf.nn.relu),
        keras.layers.Dense(10, activation=tf.nn.softmax)
    ])                                                               # 1、建立神经网络模型(类属性)

    def __init__(self):                                              # 2、获取数据集,进行数据的归一化(关键一步)
        (self.x_train, self.y_train), (self.x_test, self.y_test) = keras.datasets.fashion_mnist.load_data()
        self.x_train = self.x_train / 255.0
        self.x_test = self.x_test / 255.0

    def compile(self):                                               # 3、编译模型优化器、损失、准确率
        SingleNN.model.compile(optimizer=keras.optimizers.Adam(),
                               loss=keras.losses.sparse_categorical_crossentropy,
                               metrics=['accuracy'])
        return None

    def fit(self):                                                   # 4、模型训练
        # 调用TensorBoard回调函数,保存事件文件在tensorboard中实时观察
        tensorboard = keras.callbacks.TensorBoard(log_dir='./summary/', write_graph=True, write_images=True)  # 参数histogram_freq=1加上报错
        SingleNN.model.fit(self.x_train, self.y_train, epochs=1, batch_size=32, callbacks=[tensorboard])
        return None

    def evaluate(self):                                              # 5、模型评估
        test_loss, test_acc = SingleNN.model.evaluate(self.x_test, self.y_test)
        print("损失评估结果:", test_loss)
        print("准确率评估结果:", test_acc)
        return None

    def predict(self):                                               # 6、使用训练好的模型,进行预测
        if os.path.exists("./ckpt/SingleNN.h5"):
            SingleNN.model.load_weights("./ckpt/SingleNN.h5")
        predictions = SingleNN.model.predict(self.x_test)
        # 对预测结果进行处理
        print(predictions, predictions.shape)
        print("真实标记:", self.y_test)
        print("预测标记:", np.argmax(predictions, axis=1))  # 获取最大预测概率的编号(0~9)
        return

if __name__ == '__main__':
    SingleNN.model.summary()  # 打印模型详情
    snn = SingleNN()
    # snn.compile()
    # snn.fit()
    # snn.evaluate()
    # SingleNN.model.save_weights("./ckpt/SingleNN.h5")
    snn.predict()


===运行结果:=============================================

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
flatten (Flatten)            (None, 784)               0         
_________________________________________________________________
dense (Dense)                (None, 128)               100480    
_________________________________________________________________
dense_1 (Dense)              (None, 128)               16512     
_________________________________________________________________
dense_2 (Dense)              (None, 10)                1290      
=================================================================
Total params: 118,282
Trainable params: 118,282
Non-trainable params: 0
_________________________________________________________________
[[5.0759772e-05 3.5075584e-06 2.7417283e-05 ... 3.0054134e-01
  5.4775714e-04 5.3799856e-01]
 [5.0181360e-04 1.0977338e-06 9.1766268e-01 ... 1.3649445e-08
  1.0163025e-04 2.8169373e-07]
 [5.9898583e-05 9.9991202e-01 5.0342514e-06 ... 2.1224409e-08
  5.5229225e-07 1.6276662e-09]
 ...
 [1.0267869e-03 2.0792443e-07 4.8083955e-05 ... 2.5554016e-06
  9.9812120e-01 2.0842114e-07]
 [3.2958884e-05 9.9924195e-01 1.2787878e-05 ... 6.2861568e-06
  4.5372763e-06 9.5869041e-07]
 [1.1158469e-03 3.8320984e-05 8.4454805e-04 ... 1.5706961e-01
  1.2674028e-02 3.6080026e-03]] (10000, 10)
真实标记: [9 2 1 ... 8 1 5]
预测标记: [9 2 1 ... 8 1 5]
  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值