训练模型就是为了让其作预测. 笔记如下.
model_fn
预测的原理是, tf 进程构建了预测 mode下的计算图, 然后从 model_dir 中恢复变量, 就绪后作预测.
# _model_fn 定义
def _model_fn(features, # This is batch_features from input_fn
labels, # This is batch_labels from input_fn
mode, # An instance of tf.estimator.ModeKeys
params):
# ...
if mode == tf.estimator.ModeKeys.PREDICT:
predictions = dict()
score = tf.sigmoid(logits)
predictions['sample_key'] = sample_key_tensor
predictions['score'] = score
# predictions['doubtful_tensor'] = score
estimator_spec = tf.estimator.EstimatorSpec(
mode=tf.estimator.ModeKeys.PREDICT,
predictions=predictions
)
# ...
遍历预测结果
# 遍历预测结果
pred_result_generator = estimator.predict(input_fn=pred_input_fn)
for pred_no, pred_dict in enumerate(pred_result_generator): # 不再是batch_size
sample_key = pred_dict['sample_key'] # 单个int
pred_ctr = pred_dict['score'] # 单个float
print(sample_key, pred_ctr)
常见问题
- 计算图要严格对应
若不能严格一致
/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1417, in _do_call
raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.NotFoundError