# Add the export graph: define the serving input signature and write a SavedModel.


def serving_input_fn():
    """Build the serving-time input signature for the exported model.

    Creates one placeholder per feature the model consumes. The same dict is
    exposed both as the model's ``features`` and as the SavedModel's
    ``receiver_tensors``, so a serving client feeds exactly these named inputs.

    Returns:
        A ``tf.estimator.export.ServingInputReceiver`` whose ``features`` and
        ``receiver_tensors`` are the placeholder dict built below.
    """
    # Token-level inputs share the [batch, max_seq_length] shape.
    seq_shape = [None, FLAGS.max_seq_length]
    input_ids = tf.placeholder(dtype=tf.int32, shape=seq_shape, name="input_ids")
    input_mask = tf.placeholder(dtype=tf.int32, shape=seq_shape, name="input_mask")
    segment_ids = tf.placeholder(dtype=tf.int32, shape=seq_shape, name="segment_ids")
    # NOTE(review): presumably model_fn reads features["label_ids"] even in
    # PREDICT mode, so a (dummy) label input is kept in the signature — confirm.
    label_ids = tf.placeholder(dtype=tf.int32, shape=[None], name="label_ids")
    string_2 = tf.placeholder(dtype=tf.string, shape=[None], name="string_2")
    features = {
        'input_ids': input_ids,
        'input_mask': input_mask,
        'segment_ids': segment_ids,
        'label_ids': label_ids,
        'string_2': string_2,
    }
    # Core Estimators (tf.estimator.Estimator and subclasses such as
    # TPUEstimator) expect a ServingInputReceiver here. The legacy
    # tf.contrib.learn.InputFnOps tuple has no `receiver_tensors` field.
    # NOTE(review): assumes `estimator` is a core tf.estimator.Estimator — confirm.
    return tf.estimator.export.ServingInputReceiver(
        features=features, receiver_tensors=features)


# Base directory for the export; per the Estimator docs, each export is
# written to a timestamped subdirectory beneath it.
export_dir = os.path.join(FLAGS.output_dir, "saved_model")
# Pass the function itself, not its result: the Estimator calls it inside the
# fresh graph it builds for export.
estimator.export_savedmodel(
    export_dir,
    serving_input_fn,
    # Bundle the vocabulary file with the SavedModel's extra assets as vocab.txt.
    assets_extra={"vocab.txt": FLAGS.vocab_file},
    as_text=False,
    strip_default_attrs=True,
)
TensorFlow Estimator export_savedmodel 简单用法
最新推荐文章于 2024-05-15 21:18:39 发布