keras里predict函数预测速度慢的优化方法

需求分析:在keras模型中,使用predict函数对1.9kw个样本进行预测,但是速度较慢

(1)tensorflow版本:

import tensorflow as tf
from tensorflow import keras
print(tf.__version__)
2.0.0

(2)导入模型及数据:

model = keras.models.load_model('../../CLDNN/save_model/**') 
test_input = np.array(st_f[0:0+10000])

(3)不同的预测方式:

clock = time.time()
test = model.predict(test_input)  # predict函数,如果不指定batch_size, 默认是32
print(time.time() - clock,'s')
3.033808946609497 s
clock = time.time()
# model.predict(np.array(st_input[k:k+100]))
model.predict(test_input,batch_size=len(test_input))  # predict函数,如果不指定batch_size, 默认是32
print(time.time() - clock,'s')
0.5662539005279541 s
clock = time.time()
test3 = model(test_input, training=False)
print(time.time() - clock,'s')
0.41847729682922363 s
clock = time.time()
test4 = model.predict(tf.data.Dataset.from_tensors(test_input))
print(time.time() - clock,'s')
0.5084314346313477 s

(4)总结:
关于预测,使用model(test_input, training=False)的速度最优。另外,当predict函数的输入指定为tensor时,无法设置batch_size。

参考:tensorflow 2.5中,采用keras,predict很慢,300倍减少predict运行时间的优化方法

  • 1
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值