TensorFlow——基于 Keras 子类化 API 的 Fashion-MNIST 数据集图像分类

https://tensorflow.google.cn/tutorials/keras/classification  

解决方案 

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@version: 0.0.1
author: ShenTuZhiGang
@time: 2021/01/25 16:33
@file: 12.py
@function:
@modify:
"""

from tensorflow import keras
import tensorflow as tf
import mnist_reader
import numpy as np
import matplotlib.pyplot as plt
from tensorflow import summary
import datetime


# Per-run TensorBoard log directories on Google Drive, suffixed with the current
# POSIX timestamp so every run writes into fresh folders.
current_time = str(datetime.datetime.now().timestamp())
train_log_dir = '/content/drive/My Drive/colab notebooks/output/tsboardx/train/' + current_time
test_log_dir = '/content/drive/My Drive/colab notebooks/output/tsboardx/test/' + current_time
val_log_dir = '/content/drive/My Drive/colab notebooks/output/tsboardx/val/' + current_time
# One summary writer per split, created in train -> val -> test order.
# NOTE(review): only the test writer is used later in the visible script.
train_summary_writer, val_summary_writer, test_summary_writer = (
    summary.create_file_writer(log_dir)
    for log_dir in (train_log_dir, val_log_dir, test_log_dir)
)

# Load Fashion-MNIST from local files via the project's mnist_reader helper.
# NOTE(review): presumably images come back as (N, 28, 28) arrays and labels as
# integer class ids 0-9 matching class_names below — confirm against mnist_reader.
(train_images, train_labels), (test_images, test_labels) = mnist_reader.load_data('../data/fashion')
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

# Rescale pixels to [0, 1] (assuming 8-bit 0-255 input values).
train_images, test_images = train_images / 255.0, test_images / 255.0

# Preview the first 25 training images in a 5x5 grid, captioned with class names.
grid_side = 5
plt.figure(figsize=(10, 10))
for cell in range(grid_side * grid_side):
    plt.subplot(grid_side, grid_side, cell + 1)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(train_images[cell], cmap=plt.cm.binary)
    plt.xlabel(class_names[train_labels[cell]])
plt.show()

class FashionMnistModel(keras.Model):
    """Small dense classifier for 28x28 inputs, written with the Keras subclassing API.

    Architecture: Flatten -> Dense(128, ReLU) -> Dense(10). The last layer has no
    activation, so the model emits raw logits; train it with a loss built with
    ``from_logits=True`` or wrap it with a ``Softmax`` layer to get probabilities.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Sub-layers are tracked by Keras through these attributes; names and creation
        # order are kept unchanged so saved weights presumably keep their paths/order.
        # NOTE(review): `input_shape` on an inner layer of a subclassed model is likely
        # only a hint — the model is built explicitly via build() by the caller.
        self.input_ = keras.layers.Flatten(input_shape=[28, 28])
        self.hidden1 = keras.layers.Dense(128, activation="relu")
        self.main_output = keras.layers.Dense(10)

    def call(self, inputs, **kwargs):
        """Run the forward pass and return the 10 unnormalized class scores."""
        return self.main_output(self.hidden1(self.input_(inputs)))


# Number of training epochs; also used as the TensorBoard step for the test metrics
# so they line up with the end of training.
EPOCHS = 10

model = FashionMnistModel()
# Build with an unknown (None) batch dimension, the Keras convention; the previous
# literal 0 declared a zero-length batch instead of "any batch size".
model.build(input_shape=(None, 28, 28))
model.summary()
# The model outputs logits, hence from_logits=True.
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
history = model.fit(train_images, train_labels, epochs=EPOCHS)
test_loss, test_acc = model.evaluate(test_images, test_labels, verbose=2)

# Record the final test metrics in TensorBoard.
with test_summary_writer.as_default():
    summary.scalar('loss', test_loss, step=EPOCHS)
    summary.scalar('accuracy', test_acc, step=EPOCHS)
# Push buffered events to disk now rather than relying on a later periodic flush.
test_summary_writer.flush()
print('\nTest accuracy:', test_acc)

# Append a softmax so predictions are per-class probabilities instead of logits.
probability_model = tf.keras.Sequential([model,
                                         tf.keras.layers.Softmax()])
predictions = probability_model.predict(test_images)
print(predictions[0])
print(np.argmax(predictions[0]))
print(test_labels[0])


def plot_image(i, predictions_array, true_label, img, names=None):
    """Show image ``img[i]`` captioned with its predicted and true class.

    Args:
        i: Index into ``true_label`` and ``img``.
        predictions_array: Per-class scores for this one image; presumably
            probabilities in [0, 1], since the maximum is printed as a percentage.
        true_label: Sequence of ground-truth class indices.
        img: Sequence of images.
        names: Optional class-index -> name lookup. Defaults to the module-level
            ``class_names`` so existing callers are unaffected.

    The caption is "<predicted> <confidence>% (<true>)", drawn in blue when the
    prediction is correct and in red otherwise.
    """
    if names is None:
        names = class_names
    label = true_label[i]
    image = img[i]

    plt.grid(False)
    plt.xticks([])
    plt.yticks([])
    plt.imshow(image, cmap=plt.cm.binary)

    predicted_label = np.argmax(predictions_array)
    color = 'blue' if predicted_label == label else 'red'
    plt.xlabel("{} {:2.0f}% ({})".format(names[predicted_label],
                                         100 * np.max(predictions_array),
                                         names[label]),
               color=color)


def plot_value_array(i, predictions_array, true_label):
    """Draw a bar chart of the per-class scores for sample ``i``.

    Args:
        i: Index into ``true_label`` for the sample being plotted.
        predictions_array: 1-D per-class scores for that sample; the y-axis is
            fixed to [0, 1], so probabilities are expected.
        true_label: Sequence of ground-truth class indices.

    The predicted class's bar is red and the true class's bar is blue. When they
    coincide the bar ends up blue, because the true class is coloured last.
    """
    label = true_label[i]
    # Derive the class count from the scores instead of hard-coding 10.
    num_classes = len(predictions_array)

    plt.grid(False)
    plt.xticks(range(num_classes))
    plt.yticks([])
    bars = plt.bar(range(num_classes), predictions_array, color="#777777")
    plt.ylim([0, 1])

    predicted_label = np.argmax(predictions_array)
    bars[predicted_label].set_color('red')
    bars[label].set_color('blue')

# Inspect two individual test samples: the captioned image on the left and the
# class-probability bar chart on the right.
for i in (0, 12):
    plt.figure(figsize=(6, 3))
    plt.subplot(1, 2, 1)
    plot_image(i, predictions[i], test_labels, test_images)
    plt.subplot(1, 2, 2)
    plot_value_array(i, predictions[i], test_labels)
    plt.show()


# Plot the first num_images test images with their predicted and true labels.
# Correct predictions are coloured blue, incorrect ones red.
num_rows = 5
num_cols = 3
num_images = num_rows * num_cols
plt.figure(figsize=(2 * 2 * num_cols, 2 * num_rows))
for i in range(num_images):
    plt.subplot(num_rows, 2 * num_cols, 2 * i + 1)
    plot_image(i, predictions[i], test_labels, test_images)
    plt.subplot(num_rows, 2 * num_cols, 2 * i + 2)
    plot_value_array(i, predictions[i], test_labels)
plt.tight_layout()
plt.show()


# Grab a single image from the test dataset.
img = test_images[1]
print(img.shape)

# Keras predicts on batches, so add a leading batch axis of size 1.
img = np.expand_dims(img, 0)
print(img.shape)

predictions_single = probability_model.predict(img)
print(predictions_single)

plot_value_array(1, predictions_single[0], test_labels)
_ = plt.xticks(range(len(class_names)), class_names, rotation=45)
# Without an explicit show() this last chart is never displayed when run as a script.
plt.show()

print(np.argmax(predictions_single[0]))

参考文章

TensorFlow——本地加载fashion-mnist数据集

TensorFlow 教程——基本分类:对服装图像进行分类

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Starzkg

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值