Tensorflow 1.4中,Keras作为作为核心模块可以直接通过tf.keas进行调用,但是考虑到keras对tfrecords文件进行操作比较麻烦,而将keras模型转成tensorflow中的另一个高级API -- Estimator模型,然后就可以调用Dataset API进行对tfrecords进行操作用来训练/评估模型。而keras本身也用到了Estimator API并且提供了tf.keras.estimator.model_to_estimator函数将keras模型可以很方便的转换成Estimator模型,因此用Keras API搭建模型框架然后用Dataset API操作IO,然后用Estimator训练模型是一套比较方便高效的操作流程。
注:
tf.keras.estimator.model_to_estimator这个函数只在tf.keras下面有在原生的keras中是没有这个函数的。
Estimator训练的模型类型主要有regressor和classifier两类,如果需要用自定义的模型类型,可以通过自定有model_fn来构建,具体操作可以查看这里
Estimator模型可以通过export_savedmodel()函数输出训练好的estimator模型,然后可以把模型创建服务接受输入数据并输出结果,这在大规模云端部署的时候会非常有用(具体操作流程可以看这里)。
1. 利用Keras搭建模型框架并转换成estimator模型
比如我们利用keras的ResNet50构建二分类模型:
import tensorflow as tf
import os
resnet = tf.keras.applications.resnet50
def my_model_fn():
base_model = resnet.ResNet50(include_top=True, # include fully layers or not
weights='imagenet', # pre-trained weights
input_shape=(224, 224, 3), # default input shape
classes=2)
base_model.summary()
optimizer = tf.keras.optimizers.RMSprop(lr=2e-3,
d