tensorflow2.0 (2)使用AlexNet8网络预测cifar10数据集中分类

接上一节,本使用AlexNet8网络预测cifar10数据集中分类

参考文章使用AlexNet8网络实现10分类

1)构建网络,训练,保存模型文件

cifar10_alexnet8_sequential.py

import tensorflow as tf
import os
import numpy as np
from matplotlib import pyplot as plt
from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, MaxPool2D, Dropout, Flatten, Dense
from tensorflow.keras import Model

np.set_printoptions(threshold=np.inf)

cifar10 = tf.keras.datasets.cifar10
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
#AlexNet模型
model = tf.keras.models.Sequential([
    # 网络结构
    Conv2D(filters=96, kernel_size=(3, 3)),
    BatchNormalization(),
    Activation('relu'),
    MaxPool2D(pool_size=(3, 3), strides=2),

    Conv2D(filters=256, kernel_size=(3, 3)),
    BatchNormalization(),
    Activation('relu'),
    MaxPool2D(pool_size=(3, 3), strides=2),

    Conv2D(filters=384, kernel_size=(3, 3), padding='same', activation='relu'),

    Conv2D(filters=384, kernel_size=(3, 3), padding='same', activation='relu'),

    Conv2D(filters=256, kernel_size=(3, 3), padding='same', activation='relu'),
    MaxPool2D(pool_size=(3, 3), strides=2),

    Flatten(),
    Dense(2048, activation='relu'),
    Dropout(0.5),
    Dense(2048, activation='relu'),
    Dropout(0.5),
    Dense(10, activation='softmax')
])
#=============================上面网络================================



model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['sparse_categorical_accuracy'])

checkpoint_save_path = "./mycheckpoint/AlexNet8.ckpt"
if os.path.exists(checkpoint_save_path + '.index'):
    print('-------------load the model-----------------')
    model.load_weights(checkpoint_save_path)

cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,
                                                 save_weights_only=True,
                                                 save_best_only=True)

history = model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1,
                    callbacks=[cp_callback])
model.summary()

# print(model.trainable_variables)
# file = open('./weights.txt', 'w')
# for v in model.trainable_variables:
#     file.write(str(v.name) + '\n')
#     file.write(str(v.shape) + '\n')
#     file.write(str(v.numpy()) + '\n')
# file.close()

###############################################    show   ###############################################

# 显示训练集和验证集的acc和loss曲线
acc = history.history['sparse_categorical_accuracy']
val_acc = history.history['val_sparse_categorical_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']

plt.subplot(1, 2, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.title('Training and Validation Accuracy')
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.show()

2)加载网络模型(.ckpt),预测自我图片

cifar10_alexnet8_app.py

from PIL import Image
import numpy as np
import tensorflow as tf
import  cv2
from matplotlib import pyplot as plt
from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, MaxPool2D, Dropout, Flatten, Dense
import print_category_self as pc
model_save_path = './mycheckpoint/AlexNet8.ckpt'

# 复现网络
model = tf.keras.models.Sequential([
    # 网络结构
    Conv2D(filters=96, kernel_size=(3, 3)),
    BatchNormalization(),
    Activation('relu'),
    MaxPool2D(pool_size=(3, 3), strides=2),

    Conv2D(filters=256, kernel_size=(3, 3)),
    BatchNormalization(),
    Activation('relu'),
    MaxPool2D(pool_size=(3, 3), strides=2),

    Conv2D(filters=384, kernel_size=(3, 3), padding='same', activation='relu'),

    Conv2D(filters=384, kernel_size=(3, 3), padding='same', activation='relu'),

    Conv2D(filters=256, kernel_size=(3, 3), padding='same', activation='relu'),
    MaxPool2D(pool_size=(3, 3), strides=2),

    Flatten(),
    Dense(2048, activation='relu'),
    Dropout(0.5),
    Dense(2048, activation='relu'),
    Dropout(0.5),
    Dense(10, activation='softmax')
])

# 加载参数
model.load_weights(model_save_path)

#测试自己下载的一张图片iamairplan.jpg
img = cv2.imread('./iamairplan.jpg')
print(img.shape)
plt.imshow(img, cmap='Greys')
plt.show()
# 将所给图片变换成32X32大小
# 可以看到,刚读出来的图片是 3 个通道的彩图;我们上面训练的也使用的 3 通道彩图;
# 所以我们要对这个图片进行 resize;但是 resize 操作不能直接对 3 通道的图片做;所以:
# 我们按照 opencv 读图片的通道顺序 b, g, r (注意不是 rgb) 使用 cv2.split() 函数对数据解包;得到了每个通道之后我们分别做 resize 操作,最后再用 cv2.merge() 将三个通道叠加起来;这样我们就可以得到我们想要的结果了
b,g,r = cv2.split(img)
print(b.shape,g.shape,r.shape)

b_resize = cv2.resize(b,(32,32))
g_resize = cv2.resize(g,(32,32))
r_resize = cv2.resize(r,(32,32))

new_img = cv2.merge((b_resize,g_resize,r_resize))
print(new_img)
print(new_img.shape)
plt.imshow(new_img, cmap='Greys')
plt.show()


#归一化
new_img = new_img / 255.0
#把矩阵转化为4维
input_img = new_img.reshape(1,32,32,3)
print(input_img.shape)

#进行预测
result = model.predict(input_img)
print(result)

# # 输出最大预测值。
pred = tf.argmax(result, axis=1)

print('\n')
tf.print(pred)# 预测结构为【0】,代表飞机,预测正确。

#上面已经结束了,下面将输出值配对名称输出。调用print_category_self.py中的print_category();
category = str(pred.numpy())
print(category)

pc.print_category(category)
print("\n")

下面的文件就是参考文章的预测处理,将预测输出值,与对应类别名字匹配

print_category.py

def print_category(category):
    if category == '[0]':
        print('飞机')
    elif category == '[1]':
        print('汽车')
    elif category == '[2]':
        print('鸟')
    elif category == '[3]':
        print('猫')
    elif category == '[4]':
        print('鹿')
    elif category == '[5]':
        print('狗')
    elif category == '[6]':
        print('青蛙')
    elif category == '[7]':
        print('马')
    elif category == '[8]':
        print('船')
    else:
        print('卡车')
写在最后:欢迎大家进行指导和讨论
  • 0
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值