[tensorflow]tf.keras入门2-分类

目录

Fashion MNIST数据库

分类模型的建立

模型预测

总体代码


主要介绍基于tf.keras的Fashion MNIST数据库分类,

官方文档地址为:https://tensorflow.google.cn/tutorials/keras/basic_classification

文本分类类似,官网文档地址为https://tensorflow.google.cn/tutorials/keras/basic_text_classification

首先是函数的调用,对于tensorflow只有在版本1.2以上的版本才有tf.keras库。另外推荐使用python3,而不是python2。

# TensorFlow and tf.keras
import tensorflow as tf
from tensorflow import keras

# 其他库
import numpy as np
import matplotlib.pyplot as plt
#查看版本
print(tf.__version__)
#1.9.0

Fashion MNIST数据库

fashion mnist数据库是mnist数据库的一个拓展。目的是取代mnist数据库,类似MINST数据库,fashion mnist数据库为训练集60000张,测试集10000张的28X28大小的服装彩色图片。具体分类如下:

标注编号描述
0T-shirt/top(T恤)
1Trouser(裤子)
2Pullover(套衫)
3Dress(裙子)
4Coat(外套)
5Sandal(凉鞋)
6Shirt(汗衫)
7Sneaker(运动鞋)
8Bag(包)
9Ankle boot(踝靴)

样本描述如下:

名称描述样本数量文件大小链接
train-images-idx3-ubyte.gz训练集的图像60,00026 MBytes下载
train-labels-idx1-ubyte.gz训练集的类别标签60,00029 KBytes下载
t10k-images-idx3-ubyte.gz测试集的图像10,0004.3 MBytes下载
t10k-labels-idx1-ubyte.gz测试集的类别标签10,0005.1 KBytes下载

单张图像展示代码:

#分类标签
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
#单张图像展示,推荐使用python3
plt.figure()
plt.imshow(train_images[0])
#添加颜色渐变条
plt.colorbar()
#不显示网格线
plt.gca().grid(False)

效果图:

样本的展示代码:

#图像预处理
train_images = train_images / 255.0
test_images = test_images / 255.0

#样本展示
plt.figure(figsize=(10,10))
for i in range(25):
    plt.subplot(5,5,i+1)
    plt.xticks([])
    plt.yticks([])
    plt.grid('off')
    plt.imshow(train_images[i], cmap=plt.cm.binary)
    plt.xlabel(class_names[train_labels[i]])

效果图:

分类模型的建立

检测模型输入数据为28X28,1个隐藏层节点数为128,输出类别10类,代码如下:

#检测模型    
model = keras.Sequential([
    keras.layers.Flatten(input_shape=(28, 28)),
    keras.layers.Dense(128, activation=tf.nn.relu),
    keras.layers.Dense(10, activation=tf.nn.softmax)
])    

模型训练参数设置:

model.compile(optimizer=tf.train.AdamOptimizer(), 
          loss='sparse_categorical_crossentropy', #多分类的对数损失函数
          metrics=['accuracy']) #准确度

模型的训练:

model.fit(train_images, train_labels, epochs=5)

模型预测

预测函数:

predictions = model.predict(test_images)

分类器是softmax分类器,输出的结果一个predictions是一个长度为10的数组,数组中每一个数字的值表示其所对应分类的概率值。如下所示:

predictions[0]
array([2.1840347e-07, 1.9169457e-09, 4.5915922e-08, 5.3185740e-08,
       6.6372898e-08, 2.6090498e-04, 6.5197796e-06, 4.7861701e-03,
       2.9425648e-06, 9.9494308e-01], dtype=float32)

对于predictions[0]其中第10个值最大,则该值对应的分类为class[9]ankle boot。

np.argmax(predictions[0]) #9
test_labels[0] #9

前25张图的分类效果展示:

#前25张图分类效果
plt.figure(figsize=(10,10))
for i in range(25):
    plt.subplot(5,5,i+1)
    plt.xticks([])
    plt.yticks([])
    plt.grid('off')
    plt.imshow(test_images[i], cmap=plt.cm.binary)
    predicted_label = np.argmax(predictions[i])
    true_label = test_labels[i]
    if predicted_label == true_label:
      color = 'green'
    else:
      color = 'red'
    plt.xlabel("{} ({})".format(class_names[predicted_label], 
                                  class_names[true_label]),
                                  color=color)

效果图,绿色标签表示分类正确,红色标签表示分类错误:

对于单个图像的预测,需要将图像28X28的输入转换为1X28X28的输入,转换函数为np.expand_dims。函数使用如下:https://www.zhihu.com/question/265545749

#格式转换
img = (np.expand_dims(img,0))
print(img.shape) #1X28X28

predictions = model.predict(img)
prediction = predictions[0]
np.argmax(prediction) #9

总体代码

# TensorFlow and tf.keras
import tensorflow as tf
from tensorflow import keras

# 其他库
import numpy as np
import matplotlib.pyplot as plt
#查看版本
print(tf.__version__)
#1.9.0

fashion_mnist = keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

#分类标签
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
#单张图像展示,推荐使用python3
plt.figure()
plt.imshow(train_images[0])
#添加颜色渐变条
plt.colorbar()
#不显示网格线
plt.gca().grid(False)

#图像预处理
train_images = train_images / 255.0
test_images = test_images / 255.0

#样本展示
plt.figure(figsize=(10,10))
for i in range(25):
    plt.subplot(5,5,i+1)
    plt.xticks([])
    plt.yticks([])
    plt.grid('off')
    plt.imshow(train_images[i], cmap=plt.cm.binary)
    plt.xlabel(class_names[train_labels[i]])

#检测模型    
model = keras.Sequential([
    keras.layers.Flatten(input_shape=(28, 28)),
    keras.layers.Dense(128, activation=tf.nn.relu),
    keras.layers.Dense(10, activation=tf.nn.softmax)
])    

model.compile(optimizer=tf.train.AdamOptimizer(), 
          loss='sparse_categorical_crossentropy', #多分类的对数损失函数
          metrics=['accuracy']) #准确度

model.fit(train_images, train_labels, epochs=5)

predictions = model.predict(test_images)

#前25张图分类效果
plt.figure(figsize=(10,10))
for i in range(25):
    plt.subplot(5,5,i+1)
    plt.xticks([])
    plt.yticks([])
    plt.grid('off')
    plt.imshow(test_images[i], cmap=plt.cm.binary)
    predicted_label = np.argmax(predictions[i])
    true_label = test_labels[i]
    if predicted_label == true_label:
      color = 'green'
    else:
      color = 'red'
    plt.xlabel("{} ({})".format(class_names[predicted_label], 
                                  class_names[true_label]),
                                  color=color)
    
#单个图像检测
img = test_images[0]
print(img.shape) #28X28

#格式转换
img = (np.expand_dims(img,0))
print(img.shape) #1X28X28

predictions = model.predict(img)
prediction = predictions[0]
np.argmax(prediction) #9

 

  • 1
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值