tensorflow笔记【8】深度学习-预测
提示:以下是本篇文章正文内容,下面案例可供参考
一、应用程序,给图识物
predict
predict(x, batch_size=None, verbose=0, steps=None)
为输入样本生成输出预测。
计算是分批进行的
参数
x: 输入数据,Numpy 数组 (或者 Numpy 数组的列表,如果模型有多个输出)。
batch_size: 整数。如未指定,默认为 32。
verbose: 日志显示模式,0 或 1。
steps: 声明预测结束之前的总步数(批次样本)。默认值 None。
返回
预测的 Numpy 数组(或数组列表)。
二、使用步骤
1.复现模型,前向传播
2.load_weight读取已有参数
3.使用predicth函数根据输入特征输出预测结果
代码如下:
# 1.导入相关模块---import
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
import tensorflow as tf
import os
from tensorflow.keras.layers import Flatten, Dense
from tensorflow.keras import Model
# 2.搭建网络模型----class
model_save_path = './checkpoint/checkpoint.ckpt' # 最优模型保存路径
class MnistModel(Model):
def