接上一节,本节使用AlexNet8网络预测cifar10数据集中的分类
1)构建网络,训练,保存模型文件
cifar10_alexnet8_sequential.py
import tensorflow as tf
import os
import numpy as np
from matplotlib import pyplot as plt
from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, MaxPool2D, Dropout, Flatten, Dense
from tensorflow.keras import Model
# Print full numpy arrays (no truncation) when dumping weights for inspection.
np.set_printoptions(threshold=np.inf)
# Load CIFAR-10: 50k training / 10k test color images, shape (32, 32, 3),
# integer labels 0-9.
cifar10 = tf.keras.datasets.cifar10
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
# Scale pixel values from [0, 255] to [0, 1].
x_train, x_test = x_train / 255.0, x_test / 255.0
# AlexNet8 model (8 weight layers: 5 conv + 3 dense), adapted to 32x32
# CIFAR-10 inputs with 3x3 kernels instead of the original large kernels.
model = tf.keras.models.Sequential([
    # Conv block 1: conv -> batch norm -> relu -> max pool
    Conv2D(filters=96, kernel_size=(3, 3)),
    BatchNormalization(),
    Activation('relu'),
    MaxPool2D(pool_size=(3, 3), strides=2),
    # Conv block 2: conv -> batch norm -> relu -> max pool
    Conv2D(filters=256, kernel_size=(3, 3)),
    BatchNormalization(),
    Activation('relu'),
    MaxPool2D(pool_size=(3, 3), strides=2),
    # Conv blocks 3-5: three convolutions, then a final max pool
    Conv2D(filters=384, kernel_size=(3, 3), padding='same', activation='relu'),
    Conv2D(filters=384, kernel_size=(3, 3), padding='same', activation='relu'),
    Conv2D(filters=256, kernel_size=(3, 3), padding='same', activation='relu'),
    MaxPool2D(pool_size=(3, 3), strides=2),
    # Classifier head: two dropout-regularized dense layers, then softmax
    # over the 10 CIFAR-10 classes.
    Flatten(),
    Dense(2048, activation='relu'),
    Dropout(0.5),
    Dense(2048, activation='relu'),
    Dropout(0.5),
    Dense(10, activation='softmax')
])
# ============================= network above =================================
model.compile(optimizer='adam',
              # Labels are integer class ids, so use sparse cross-entropy;
              # the last layer already applies softmax, hence from_logits=False.
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['sparse_categorical_accuracy'])
checkpoint_save_path = "./mycheckpoint/AlexNet8.ckpt"
# Resume from an existing checkpoint; TF writes an '.index' file alongside
# the weight shards, so its presence indicates a complete checkpoint.
if os.path.exists(checkpoint_save_path + '.index'):
    print('-------------load the model-----------------')
    model.load_weights(checkpoint_save_path)
# Save weights only (no architecture), keeping just the best epoch.
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,
                                                 save_weights_only=True,
                                                 save_best_only=True)
history = model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1,
                    callbacks=[cp_callback])
model.summary()
# Optionally dump every trainable variable (name, shape, values) to a text
# file for inspection; relies on np.set_printoptions above to avoid truncation.
# print(model.trainable_variables)
# file = open('./weights.txt', 'w')
# for v in model.trainable_variables:
#     file.write(str(v.name) + '\n')
#     file.write(str(v.shape) + '\n')
#     file.write(str(v.numpy()) + '\n')
# file.close()
############################################### show ###############################################
# Plot training/validation accuracy and loss curves side by side.
acc = history.history['sparse_categorical_accuracy']
val_acc = history.history['val_sparse_categorical_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']

# Both panels share the same layout, so draw them in a single loop.
for position, train_curve, val_curve, metric in (
        (1, acc, val_acc, 'Accuracy'),
        (2, loss, val_loss, 'Loss')):
    plt.subplot(1, 2, position)
    plt.plot(train_curve, label=f'Training {metric}')
    plt.plot(val_curve, label=f'Validation {metric}')
    plt.title(f'Training and Validation {metric}')
    plt.legend()
plt.show()
2)加载网络模型(.ckpt),预测自我图片
cifar10_alexnet8_app.py
from PIL import Image
import numpy as np
import tensorflow as tf
import cv2
from matplotlib import pyplot as plt
from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, MaxPool2D, Dropout, Flatten, Dense
import print_category_self as pc
model_save_path = './mycheckpoint/AlexNet8.ckpt'
# Rebuild the exact architecture used in training: the checkpoint stores
# weights only, so the layer structure must match layer-for-layer.
model = tf.keras.models.Sequential([
    # Conv block 1
    Conv2D(filters=96, kernel_size=(3, 3)),
    BatchNormalization(),
    Activation('relu'),
    MaxPool2D(pool_size=(3, 3), strides=2),
    # Conv block 2
    Conv2D(filters=256, kernel_size=(3, 3)),
    BatchNormalization(),
    Activation('relu'),
    MaxPool2D(pool_size=(3, 3), strides=2),
    # Conv blocks 3-5
    Conv2D(filters=384, kernel_size=(3, 3), padding='same', activation='relu'),
    Conv2D(filters=384, kernel_size=(3, 3), padding='same', activation='relu'),
    Conv2D(filters=256, kernel_size=(3, 3), padding='same', activation='relu'),
    MaxPool2D(pool_size=(3, 3), strides=2),
    # Classifier head
    Flatten(),
    Dense(2048, activation='relu'),
    Dropout(0.5),
    Dense(2048, activation='relu'),
    Dropout(0.5),
    Dense(10, activation='softmax')
])
# Load the trained weights into the rebuilt model.
model.load_weights(model_save_path)
# Test on a downloaded image, iamairplan.jpg.
img = cv2.imread('./iamairplan.jpg')
print(img.shape)
# NOTE(review): cv2.imread returns BGR channel order while matplotlib expects
# RGB, so the colors displayed here will look swapped; the cmap argument is
# ignored for 3-channel images.
plt.imshow(img, cmap='Greys')
plt.show()
# Resize the image to 32x32 to match the network's training input size.
# The original article claims resize cannot be applied to a 3-channel image,
# so it splits the B, G, R channels (OpenCV order, not RGB), resizes each one
# separately, and merges them back with cv2.merge().
# NOTE(review): cv2.resize actually handles multi-channel images directly —
# cv2.resize(img, (32, 32)) would be equivalent; the per-channel version is
# kept here to mirror the article.
b,g,r = cv2.split(img)
print(b.shape,g.shape,r.shape)
b_resize = cv2.resize(b,(32,32))
g_resize = cv2.resize(g,(32,32))
r_resize = cv2.resize(r,(32,32))
new_img = cv2.merge((b_resize,g_resize,r_resize))
print(new_img)
print(new_img.shape)
plt.imshow(new_img, cmap='Greys')
plt.show()
# Normalize pixel values to [0, 1], matching the training preprocessing.
new_img = new_img / 255.0
# Add a leading batch dimension: the model expects (batch, 32, 32, 3).
input_img = new_img.reshape(1,32,32,3)
print(input_img.shape)
# Run inference; result is a (1, 10) array of softmax class probabilities.
result = model.predict(input_img)
print(result)
# Take the index of the highest probability as the predicted class.
pred = tf.argmax(result, axis=1)
print('\n')
tf.print(pred)  # e.g. [0] means class 0 = airplane, a correct prediction here.
# Map the numeric prediction to a class name by passing its string form
# (e.g. '[0]') to print_category() from print_category_self.py.
category = str(pred.numpy())
print(category)
pc.print_category(category)
print("\n")
下面的文件就是参考文章的预测处理,将预测输出值,与对应类别名字匹配
print_category_self.py
def print_category(category):
    """Print the Chinese CIFAR-10 class name for a prediction string.

    `category` is the string form of the predicted label array, e.g. '[0]'.
    Any unrecognized value falls through to '卡车' (truck), matching the
    original if/elif chain's final else branch.
    """
    class_names = {
        '[0]': '飞机',
        '[1]': '汽车',
        '[2]': '鸟',
        '[3]': '猫',
        '[4]': '鹿',
        '[5]': '狗',
        '[6]': '青蛙',
        '[7]': '马',
        '[8]': '船',
    }
    print(class_names.get(category, '卡车'))
写在最后:欢迎大家进行指导和讨论