全篇以tensorflow案例作为来源
这篇教程使用的是泰坦尼克号乘客的数据。模型会根据乘客的年龄、性别、票务舱和是否独自旅行等特征来预测乘客生还的可能性。
首先导入需要的模块
import functools
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
TRAIN_DATA_URL和TEST_DATA_URL是数据集的网址,使用tf.keras.utils.get_file从TRAIN_DATA_URL下载数据集train.cs,同时给下载的文件命名为train.csv。
train_file_path是该文件被保存在计算机中的路径。
TRAIN_DATA_URL="https://storage.googleapis.com/tfdatasets/titanic/ train.cs "
TEST_DATA_URL = "https://storage.googleapis.com/tf-datasets/titanic/eval.csv"
train_file_path = tf.keras.utils.get_file("train.csv", TRAIN_DATA_URL)
test_file_path = tf.keras.utils.get_file("eval.csv", TEST_DATA_URL)
np.set_printoptions用于控制输出方式,precision控制小数点后输出个数,默认是8. Suppress表示小数是否以科学计数法方式输出。
np.set_printoptions(precision=3, suppress=True)表示保留3位小数,小数不需要以科学计数法形式输出。
# 让 numpy 数据更易读。
np.set_printoptions(precision=3, suppress=True)
查看train.csv文件,如果没有列名,那么需要将列名通过字符串列表传给 make_csv_dataset 函数的 column_names 参数。
CSV_COLUMNS = ['survived', 'sex', 'age', 'n_siblings_spouses', 'parch', 'fare', 'class', 'deck', 'embark_town', 'alone']
dataset = tf.data.experimental.make_csv_dataset(
column_names=CSV_COLUMNS,)
如果只需要某些列,可以使用select_columns函数。
dataset = tf.data.experimental.make_csv_dataset(
select_columns