文章目录
一、函数详解
在 Keras 中,获取模型的预测结果的两种方式:
keras_model() 直接调用模型对象
:将 Keras 模型对象当作函数一样调用,并将输入数据作为参数传递给它,从而直接获取预测结果。
- 优缺点:(1)支持动态图计算;(2)只支持单样本预测;(3)只支持Tensor类型的输入数据;(4)输出数据为Tensor类型;
- 适用范围:大规模数据;实时处理;预测速度快
keras_model.predict() 方法:
:predict() 方法是 Keras 模型对象的一个函数,用于进行推理并获取预测结果。
- 优缺点:(1)不支持动态图计算;(2)支持批量样本预测;(3)支持Tensor和NumPy类型的输入数据;(4)输出数据为NumPy类型;(5)需要一次性将所有数据加载到内存中,因此对于大型数据集,可能会导致内存不足。
- 适用范围:小规模数据;对内存占用不敏感;预测速度慢
在 PyTorch 中,只有一种方法获取模型的预测结果:
pytorch_model()
。
1.1、keras_model.predict(x)
"""###################################################################
函数说明: keras_model.predict(x, batch_size=None, verbose=0, steps=None, callbacks=None, max_queue_size=10, workers=1, use_multiprocessing=False)
输入参数:
(1)x: 输入数据。可以是 NumPy 数组、tf.data.Dataset 对象或字典等。具体取决于模型的输入层的期望。
(2)batch_size: 批处理大小。如果未指定,将使用默认的批处理大小。批处理大小表示模型一次性处理的样本数量,可以影响内存使用和预测速度。
(3)verbose: 是否显示进度信息(0表示不显示任何信息、1表示显示进度条、2表示每个epoch显示一行)
(4)steps: 指定预测结束的步数(批次数)。如果未指定,将一直进行预测,直到输入数据用尽。
(5)callbacks: 可选的回调函数列表。在预测过程中的不同阶段触发不同的回调函数,用于自定义行为。
(6)max_queue_size: 指定生成器队列的最大大小。对于生成器提供的数据,此参数可以控制内存使用。
(7)workers: 指定生成器的工作进程数量。仅在使用生成器提供数据时才相关。
(8)use_multiprocessing: 布尔值,表示是否使用多进程进行数据生成。默认为 False。如果设置为 True,则会使用 workers 个进程进行数据生成。
返回参数:
numpy数组
###################################################################"""
1.2、keras_model(x)
"""###################################################################
函数说明: keras_model(x, training=None)
输入参数:
(1)x: 输入数据。可以是 Numpy 数组、Tensor 对象或者其他可以被模型接受的输入数据。这里输入数据的形状和数据类型需要和模型的输入层相匹配。
(2)training: 布尔值,表示模型是否处于训练模式。
True(训练模式): 如:启用Dropout、 启用Batch Normalization
False(推理或预测模式): 如:不启用Dropout、 不启用Batch Normalization
返回参数:
tensor张量
###################################################################"""
二、加速测试
keras里predict函数,预测速度慢的优化方法
keras里predict很慢,300倍减少predict运行时间的优化方法
2.1、model.predict(x=input_data) —— 时耗:0.09967 秒
2.2、model.predict(x=input_data, batch_size=8) —— 时耗:0.12919 秒
2.3、model.predict(tf.data.Dataset.from_tensors(input_data)) —— 时耗:0.08310 秒
2.4、model(x=input_data, training=False) —— 时耗:0.01395 秒
import time
import numpy as np
import tensorflow as tf
if __name__ == "__main__":
print("Tensorflow版本 =", tf.__version__)
##############################################################################
# (1)新建模型
flag = 2
if flag == 1:
"""1.序列模型"""
from tensorflow import keras
model = keras.Sequential([
keras.layers.Dense(128, activation='relu', input_shape=(10,)), # 定义第一个全连接层,包括 128 个神经元,使用 ReLU 激活函数,输入形状为 (10,)。
keras.layers.Dense(64, activation='relu'), # 定义第二个全连接层,包括 64 个神经元,使用 ReLU 激活函数。
keras.layers.Dense(1, activation='sigmoid') # 定义输出层,包括 1 个神经元,使用 Sigmoid 激活函数,适用于二分类问题。
])
input_data = np.random.rand(1000, 10) # 随机生成输入数据,其中包含 1000 个样本,每个样本有 10 个特征。
elif flag == 2:
"""2.卷积模型"""
from tensorflow.keras import layers, models
model = models.Sequential([
# 卷积层,包含 32 个卷积核(filters),每个卷积核大小为 (3, 3),使用 ReLU 激活函数。input_shape=(100, 100, 3) 表示输入图像的形状为 (100, 100, 3)
layers.Conv2D(32, (3, 3), activation='relu', input_shape=(100, 50, 3)),
layers.MaxPooling2D((2, 2)),
layers.Conv2D(64, (3, 3), activation='relu'),
layers.MaxPooling2D((2, 2)),
# 将卷积层输出的多维数据展平为一维,为全连接层做准备。
layers.Flatten(),
layers.Dense(128, activation='relu'),
layers.Dense(1, activation='sigmoid')
])
input_data = np.random.rand(1000, 100, 50, 3) # 随机生成输入数据,其中包含 1000 个样本,每个图像为100x100x3(RGB)。
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy']) # 编译模型: 指在使用模型进行训练之前,配置模型参数。
model.summary() # 打印模型概要
##############################################################################
# (2)多种预测方式的时耗
for ii in range(3):
clock = time.time()
res1 = model.predict(x=input_data)
elapsed_time = time.time() - clock
print(f"总共耗时1: {elapsed_time:.5f} 秒")
clock = time.time()
res2 = model.predict(x=input_data, batch_size=8)
elapsed_time = time.time() - clock
print(f"总共耗时2: {elapsed_time:.5f} 秒")
clock = time.time()
test3 = model.predict(tf.data.Dataset.from_tensors(input_data))
elapsed_time = time.time() - clock
print(f"总共耗时3: {elapsed_time:.5f} 秒")
clock = time.time()
res4 = model(x=input_data, training=False)
elapsed_time = time.time() - clock
print(f"总共耗时4: {elapsed_time:.5f} 秒")
print(" ")
"""
总共耗时1: 0.09967 秒
总共耗时2: 0.12919 秒
总共耗时3: 0.08310 秒
总共耗时4: 0.01395 秒
"""