1、定义模型
def my_model(features,labels,mode,params):
    """Custom model_fn for tf.estimator.Estimator (skeleton only).

    Args:
        features: the actual input data produced by the input_fn, NOT the
            feature columns. With the input pipeline in these notes this is
            the feature dict returned by parse_exmp.
        labels: labels produced by the input_fn; presumably the
            {'ctr': ..., 'cvr': ...} dict returned by parse_exmp — confirm.
        mode: passed straight through to EstimatorSpec.
        params: the `params` dict given to tf.estimator.Estimator when the
            classifier is constructed.

    Returns:
        A tf.estimator.EstimatorSpec built from `loss` and `train_op`.
    """
    # NOTE(review): the network body that defines `loss` and `train_op` is
    # omitted from these notes; this line only shows the shape of the return.
    return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
2、用 tf.estimator.Estimator 包装自定义模型
# Wrap the custom model_fn in an Estimator. The `params` dict is forwarded to
# my_model as its `params` argument; `config` sets where checkpoints go and how
# often they are written.
classifier = tf.estimator.Estimator(
    model_fn=my_model,
    params={
        'feature_columns': my_feature_columns,  # list of feature columns
        # NOTE(review): split(',') yields a list of *strings* (e.g. ['128', '64']);
        # the model_fn must convert each entry to int before building layers — confirm.
        'hidden_units': FLAGS.hidden_units.split(','),
        'learning_rate': FLAGS.learning_rate,
        'dropout_rate': FLAGS.dropout_rate,
    },
    config=tf.estimator.RunConfig(
        model_dir=FLAGS.model_dir,
        save_checkpoints_steps=FLAGS.save_checkpoints_steps,
    ),
)
3、自定义生成数据的函数
这里需要注意的是:如果自定义模型时采用
# Build the model's input tensor from the parsed `features` dict using the
# feature columns passed in through params (fc presumably aliases tf.feature_column).
net = fc.input_layer(features, params['feature_columns'])
来输入数据(而不是通过 tf.placeholder 的方式),则需要按如下方式自定义生成数据的函数:
def parse_exmp(serial_exmp):
    """Parse one serialized tf.train.Example into a (features, labels) pair.

    The label columns "click" and "pay" are parsed together with the model's
    feature columns, then popped out of the feature dict, so the model only
    receives real features.

    Args:
        serial_exmp: one serialized Example record (an element of the
            TFRecordDataset built in train_input_fn).

    Returns:
        ``(feats, labels)`` where ``feats`` is the dict of parsed feature
        tensors and ``labels`` is ``{'ctr': float(click), 'cvr': float(pay)}``.
    """
    # Label columns: int64, and records without a value fall back to 0.
    click_col = fc.numeric_column("click", default_value=0, dtype=tf.int64)
    pay_col = fc.numeric_column("pay", default_value=0, dtype=tf.int64)
    # Build a fresh list so my_feature_columns is never mutated; list() also
    # accepts a tuple of feature columns.
    fea_columns = [click_col, pay_col] + list(my_feature_columns)
    feature_spec = tf.feature_column.make_parse_example_spec(fea_columns)
    # Map the serialized record onto tensors according to the combined spec.
    feats = tf.parse_single_example(serial_exmp, features=feature_spec)
    # Separate the labels from the features.
    click = feats.pop('click')
    pay = feats.pop('pay')
    return feats, {'ctr': tf.to_float(click), 'cvr': tf.to_float(pay)}


def train_input_fn(filenames, batch_size, shuffle_buffer_size, num_parallel_calls=8):
    """Build the training tf.data pipeline over TFRecord files.

    Args:
        filenames: file pattern(s) accepted by tf.data.Dataset.list_files.
        batch_size: number of examples per batch.
        shuffle_buffer_size: shuffle buffer size; shuffling is skipped when <= 0.
        num_parallel_calls: parallelism of the parse_exmp map (default 8, the
            previously hard-coded value).

    Returns:
        A dataset of batched ``(features, labels)`` pairs that repeats indefinitely.
    """
    files = tf.data.Dataset.list_files(filenames)
    # Read several TFRecord files concurrently.
    dataset = files.apply(
        tf.contrib.data.parallel_interleave(
            tf.data.TFRecordDataset, cycle_length=FLAGS.num_parallel_readers))
    # Shuffle the serialized records before parsing.
    if shuffle_buffer_size > 0:
        dataset = dataset.shuffle(shuffle_buffer_size)
    dataset = dataset.map(parse_exmp, num_parallel_calls=num_parallel_calls)
    # repeat() never ends on its own; training length is bounded by
    # TrainSpec(max_steps=...).
    dataset = dataset.repeat().batch(batch_size).prefetch(1)
    # Debug output: element types and shapes of the pipeline.
    print(dataset.output_types)
    print(dataset.output_shapes)
    # Return the read end of the pipeline.
    return dataset
4、生成TrainSpec和EvalSpec
# train_files: training data in TFRecord format.
# Training runs until FLAGS.train_steps global steps have been taken.
train_spec = tf.estimator.TrainSpec(
    input_fn=lambda: train_input_fn(
        filenames=train_files,
        batch_size=batch_size,
        shuffle_buffer_size=shuffle_buffer_size,
    ),
    max_steps=FLAGS.train_steps,
)
def input_fn_for_eval():
    """Zero-argument input_fn that forwards eval_files and batch_size to eval_input_fn."""
    return eval_input_fn(eval_files, batch_size)


# throttle_secs=600: do not re-evaluate more often than every 600 seconds.
# steps=None: evaluate until the input_fn signals end of input.
# NOTE(review): eval_input_fn (not shown) must therefore not repeat() forever,
# or evaluation never terminates — confirm.
eval_spec = tf.estimator.EvalSpec(input_fn=input_fn_for_eval, throttle_secs=600, steps=None)
5、训练并且评估
# Run training, evaluating periodically as configured by eval_spec.
tf.estimator.train_and_evaluate(
    estimator=classifier,
    train_spec=train_spec,
    eval_spec=eval_spec,
)
6、获得评估结果(如需逐条样本的预测结果,应使用 classifier.predict)
# Evaluate the trained model on the eval input; per the tf.estimator docs,
# Estimator.evaluate returns a dict of metric values, not per-example predictions.
results = classifier.evaluate(input_fn_for_eval)