1. 主要参考 TensorFlow 官方的 Iris 分类示例（demo）
2. 源代码
import tensorflow as tf
import pandas as pd
# Column names for the Iris CSV files; 'Species' is the label column.
CSV_COLUMN_NAMES = [
    'SepalLength', 'SepalWidth',
    'PetalLength', 'PetalWidth',
    'Species',
]
# Human-readable class names (presumably indexed by the integer label
# stored in the 'Species' column -- confirm against the CSV contents).
SPECIES = ['Setosa', 'Versicolor', 'Virginica']

# Training and test samples can be downloaded from:
#   http://download.tensorflow.org/data/iris_training.csv
#   http://download.tensorflow.org/data/iris_test.csv
tf.logging.set_verbosity(tf.logging.INFO)

# header=0 together with names= replaces the file's own header row with
# CSV_COLUMN_NAMES. train and train_x refer to the same DataFrame.
train_x = train = pd.read_csv(
    './iris_training.csv',
    header=0,
    names=CSV_COLUMN_NAMES,
)
# pop() removes the label column in place, so train_x keeps only features.
train_y = train_x.pop('Species')
def input_fn_1(features, labels, batch_size, buffer_size=1000):
    """Build a shuffled, repeating, batched training dataset.

    Args:
        features: Column-name -> values mapping. It is converted with
            ``dict(features)`` so each column becomes a named feature
            (presumably a pandas DataFrame such as ``train_x`` -- confirm
            at the call site).
        labels: Label values aligned row-for-row with ``features``.
        batch_size: Number of examples per batch.
        buffer_size: Size of the shuffle buffer. Defaults to 1000, the
            value that was previously hard-coded here.

    Returns:
        A ``tf.data.Dataset`` of ``(features_dict, labels)`` batches.
        ``repeat()`` is called without a count, so the dataset does not
        end on its own; the caller must bound training (e.g. by steps).
    """
    dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))
    # Shuffle, repeat indefinitely, then group examples into batches.
    dataset = dataset.shuffle(buffer_size).repeat().batch(batch_size)
    return dataset
def input_fn_2(features, labels, batch_size):
features = dict(features)
if labels is None: