TensorFlow2.0入门到进阶系列——5.1_TensorFlow进阶篇

本文详细介绍了TensorFlow2.0中的Estimator,包括预定义和自定义Estimator的使用,特别是Tf.feature_column在特征工程中的应用。通过实例展示了从Keras模型转换为Estimator的过程,并对比了TensorFlow1.0与2.0的API变化,指导如何将旧代码升级到TF2.0。
摘要由CSDN通过智能技术生成


使用TensorFlow的底层API开发机器学习模型时,需要显式地定义模型中的变量和输入数据、以及对会话进行显式的声明和管理,这需要不小的编码量。而使用更高层的Dataset工具类可以很轻松、高效地处理大量的输入数据以及不同的数据格式,相比基于feed_dict的数据输入方式更加高效和规整。并且使用Estimator工具类可以简化机器学习模型的构建过程,Estimator可以自动管理图的构建、变量初始化、模型保存及恢复过程。对一般机器学习pipeline中的训练、评估、预测三个过程进行统一管理。

0、estimator简介

Estimator封装了模型的构建、训练、评估、预估以及保存过程,将数据的输入从模型中分离出来。数据输入需要编写单独的函数。
在这里插入图片描述

1、Tf.estimator使用

  • keras转estimator
  • 使用预定义的estimator
    • BaseLineClassifier
    • LinearClassifier
    • DNNClassifier
  • Tf.feature_column做特征工程

1.1、Dataset和Estimator的完整使用流程

  • 定义用于训练和评估的输入函数input_fn;
  • 根据数据集的特点,定义好feature_column;
  • 使用预定义的Estimator或者自定义模型函数;
  • 调用Estimator的train、eval和predict方法产生结果。

2.1、API列表

  • Tf.keras.estimator.to_estimator
    • Train,evaluate
  • Tf.estimator.BaselineClassifier
  • Tf.estimator.LinearClassifier
  • Tf.estimator.DNNClassifier
  • Tf.feature_column
    • categorical_column_with_vocabulary_list
    • numweric_column
    • indicator_column
    • cross_column
  • keras.layers.DenseFeatures

2、Estimator实战

2.1、feature_column使用

加载库

import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import sklearn
import pandas as pd
import os
import sys
import time
import tensorflow as tf

from tensorflow import keras
#import keras

print(tf.__version__)
print(sys.version_info)
for module in mpl,np,pd,sklearn,tf,keras:
    print(module.__name__,module.__version__)

数据加载及处理

# 数据地址
# https://storage.googleapis.com/tf-datasets/titanic/train.csv
# https://storage.googleapis.com/tf-datasets/titanic/eval.csv
train_file = "./data/titanic/train.csv"
eval_file = "./data/titanic/eval.csv"

train_df = pd.read_csv(train_file)
eval_df = pd.read_csv(eval_file)

print(train_df.head())
print(eval_df.head())
#survived是要预测的值,不能在特征值中,所以要选出来。
y_train = train_df.pop('survived')
y_eval = eval_df.pop('survived')
#pop() 函数用于移除列表中的一个元素(默认最后一个元素),并且返回该元素的值

print(train_df.head())
print(eval_df.head())
print(y_train.head())
print(y_eval.head())
<
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值