Spark DeepLeaning4j加载Keras模型
深度学习模型在实际业务应用中,经常会结合Spark集群来处理企业内部数据。由于网络通信开销,一些模型在Spark上进行分布式训练,效率十分低下;这种情况可考虑线下使用GPU机器训练模型,再将训练后的模型部署到Spark集群上。
通过本文,你将会了解到Spark DeepLearning4j如何加载Keras模型。下面是具体的实现步骤:
- 线下训练环境中,保存Keras模型。
model.save('full_model.h5')
- 使用KerasModelImport函数,本地加载Keras模型到deeplearning4j中,并使用ModelSerializer函数保存deeplearning4j模型。(spark cluster不能从hdfs中直接加载Keras的h5模型文件)
val path = "full_model.h5"
val net = KerasModelImport.