5.1_TensorFlow进阶篇
使用TensorFlow的底层API开发机器学习模型时,需要显式地定义模型中的变量和输入数据、以及对会话进行显式的声明和管理,这需要不小的编码量。而使用更高层的Dataset工具类可以很轻松、高效地处理大量的输入数据以及不同的数据格式,相比基于feed_dict的数据输入方式更加高效和规整。并且使用Estimator工具类可以简化机器学习模型的构建过程,Estimator可以自动管理图的构建、变量初始化、模型保存及恢复过程。对一般机器学习pipeline中的训练、评估、预测三个过程进行统一管理。
0、estimator简介
Estimator封装了模型的构建、训练、评估、预估以及保存过程,将数据的输入从模型中分离出来。数据输入需要编写单独的函数。
1、Tf.estimator使用
- keras转estimator
- 使用预定义的estimator
- BaseLineClassifier
- LinearClassifier
- DNNClassifier
- Tf.feature_column做特征工程
1.1、Dataset和Estimator的完整使用流程
- 定义用于训练和评估的输入函数input_fn;
- 根据数据集的特点,定义好feature_column;
- 使用预定义的Estimator或者自定义模型函数;
- 调用Estimator的train、eval和predict方法产生结果。
2.1、API列表
- Tf.keras.estimator.to_estimator
- Train,evaluate
- Tf.estimator.BaselineClassifier
- Tf.estimator.LinearClassifier
- Tf.estimator.DNNClassifier
- Tf.feature_column
- categorical_column_with_vocabulary_list
- numweric_column
- indicator_column
- cross_column
- keras.layers.DenseFeatures
2、Estimator实战
2.1、feature_column使用
加载库
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import sklearn
import pandas as pd
import os
import sys
import time
import tensorflow as tf
from tensorflow import keras
#import keras
print(tf.__version__)
print(sys.version_info)
for module in mpl,np,pd,sklearn,tf,keras:
print(module.__name__,module.__version__)
数据加载及处理
# 数据地址
# https://storage.googleapis.com/tf-datasets/titanic/train.csv
# https://storage.googleapis.com/tf-datasets/titanic/eval.csv
train_file = "./data/titanic/train.csv"
eval_file = "./data/titanic/eval.csv"
train_df = pd.read_csv(train_file)
eval_df = pd.read_csv(eval_file)
print(train_df.head())
print(eval_df.head())
#survived是要预测的值,不能在特征值中,所以要选出来。
y_train = train_df.pop('survived')
y_eval = eval_df.pop('survived')
#pop() 函数用于移除列表中的一个元素(默认最后一个元素),并且返回该元素的值
print(train_df.head())
print(eval_df.head())
print(y_train.head())
print(y_eval.head())
<