1、定义训练模型
加载模型结构,模型结构必须与训练的模型结构一致。
model = tf.keras.models.Sequential([ tf.keras.layers.Conv2D(16, (3,3), activation='relu', input_shape=(150, 150, 3)), tf.keras.layers.MaxPooling2D(2,2), tf.keras.layers.Conv2D(32, (3,3), activation='relu'), tf.keras.layers.MaxPooling2D(2,2), tf.keras.layers.Conv2D(64, (3,3), activation='relu'), tf.keras.layers.MaxPooling2D(2,2), tf.keras.layers.Flatten(), tf.keras.layers.Dense(512, activation='relu'), tf.keras.layers.Dense(1, activation='sigmoid') ])
2、加载模型权重
加载训练保存的cp.ckpt文件,训练过程可以参考前一篇文章“基于卷积神经网络(CNN)的猫狗识别系统”。
checkpoint_path = "D:\pycharm\pythonProject1\cp.ckpt" model.load_weights(checkpoint_path)
3、加载并预测图片
def predict_image(file_path): img = image.load_img(file_path, target_size=(150, 150)) x = image.img_to_array(img) x = np.expand_dims(x, axis=0) images = np.vstack([x]) classes = model.predict(images, batch_size=10) return "狗" if classes[0]>0.5 else "猫"
4、页面搭建
使用Qt-designer搭建一个简单的操作页面。
from home_ui import Ui_MainWindow as window from PySide2 import QtWidgets, QtGui import sys # from judgeImage import JudgeImage as model import judgeImage as model class mainWindow(QtWidgets.QMainWindow, window): def __init__(self): super().__init__() self.file_name = None self.setupUi(self) # 选择文件按钮的点击事件连接到自定义方法 self.selectFile.clicked.connect(self.handle_selectIo) # 开始预测按钮的点击事件连接到自定义方法 self.biginTrain.clicked.connect(self.train) # 点击选择文件按钮时执行的操作 def handle_selectIo(self): # 设置文件过滤器 filters = "jpg files (*.jpg);;png files (*.png))" # 获取用户选择的文件名 self.file_name, _ = QtWidgets.QFileDialog.getOpenFileName( self, "选择文件", "", # 起始路径,空字符串表示当前目录 filters # 文件过滤器 ) # 将文件名显示在界面上 self.fileName.setText(self.file_name) # 显示所选文件的图像 self.display_image() # 显示所选文件的图像 def display_image(self): pixmap = QtGui.QPixmap(self.file_name) self.imageView.setPixmap(pixmap.scaled(self.imageView.size())) # 点击开始预测按钮时执行的操作 def train(self): # 如果用户没有选择文件,弹出警告提示框并返回 if self.file_name is None: QtWidgets.QMessageBox.warning(self, "警告", "请先选择文件!") return # 实例化 judgeImage 类并进行预测 prediction = model.predict_image(self.file_name) print("预测结果为:", prediction) self.result.setText(prediction) if __name__ =='__main__': app = QtWidgets.QApplication(sys.argv) temp = mainWindow() temp.show() sys.exit(app.exec_())