标签: tensorflow
原文链接
代码链接
上一篇我们看了使用tf.estimator直接构建一个DNN分类器,但是数据load进来之后,输入分类器之前,还要经过一个input_fn的函数。
这篇文章会教你怎么用input_fn来喂给一个神经网络回归器数据。
1.把feature data转换为tensor
如果你的feature/label数据是python array ,或者存在pandas dataframe 或者numpy array中,可以用下面的方法构建inputfn函数
pass inputfn data to your model
直接把input function作为参数输入train op,注意input fn是作为一个object 传入,而不是作为一个函数被调用,要不然会发生typeerror
即使要修改inputfn的参数,也不能这样用,有其他的方法
classifier.train(input_fn=my_input_fn, steps=2000) 正确
classifier.train(input_fn=my_input_fn(training_set), steps=2000) 错误
即inputfn在输入的时候不能调参,必须在定义的时候被设置到加载哪个数据集
若果不想重复定义,比如inputfntrain,inputfntest,inputfnevaluate
可以用以下四种方法
(1)、用一个包装函数
my_input_fn_training_set(),感觉也没简单多少
def my_input_fn(data_set):
...
def my_input_fn_training_set():
return my_input_fn(training_set)
classifier.train(input_fn=my_input_fn_training_set, steps=2000)