TensorFlow2.0 (4) dataset 使用

 

摘要:在第一章我们学会了用 TensorFlow2 来构建模型,第二章又学习了超参数搜索,第三章学会了 TensorFlow 基础 api 的实现。但想要训练出更好的模型,还有一步极其关键的步骤,数据的输入与处理,在实际工作项目中,数据的处理与输入甚至可能占用 60% 的时间。

一、 Dataset 基础 API使用

          1.1 tf.data.Dataset.from_tensor_slices

          repeat,batch,interleave,map,shuffle,list_files

               

                我们先使用第一种方法,从内存构建数据

# 日常 import
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import sklearn 
import pandas as pd
import os
import sys
import time
import tensorflow as tf

from tensorflow import keras

               调用 from_tensor_slices

# 初始化一个 1*10 的一维向量的 dataset
dataset = tf.data.Dataset.from_tensor_slices(np.arange(10))
print(dataset)

# 那么对 dataset 我们可以遍历
for item in dataset:
    print(item)

# 重复读取:repeat epoch 遍历一次为一次 epoch
# 一次获取的数据量:get batch
# 遍历三次。每次 7 个
dataset = dataset.repeat(3).batch(7)

# interleave: 对现有的每个元素进行处理,形成一个新的数据集
# case: 做一个变化,把文件名对应的文件内容读取出来,再讲文件内容合并起来,变成一个总的数据集

dataset2 = dataset.interleave(
    lambda v: tf.data.Dataset.from_tensor_slices(v),# map_fn :数据变换形式
    cycle_length = 5,# cycle_length : 同时处理的数据个数
    block_length = 5,# block_length : 每次取多少个元素出来

)

# 然后我们再遍历这个 dataset2
for item in dataset2:
    # 显示每个被 dataset 从 dataset2 获取到的数据
    print(item) 
# 通过这个结果我们会发现,它会在每个数组前取 5 个数,在最后不足 5 个数的时候,会从之前只取 5 个数时遗漏的尾部数据那按顺序提取数据,这样就达到了一个均匀混合的效果

              这里我们输入进去的是 np 的向量,我们尝试一下别的数据类型

              除了支持 numpy 的数据,还支持 python 原有的元组字典等,我们先尝试一下元组

              输入元组 (x,y)

# 除了支持 numpy 的数据,还支持 python 原有的元组字典等
# 设定一个二维矩阵,和二维矩阵对应的对象
x = np.array([1,2],[3,4],[5,6]])
y = np.array(['cat','dog','fox'])
# 这一步就是将两个数组以元组的方式输入
dataset3 = tf.data.Dataset.from_tensor_slices((x,y))
print(dataset3)

for item_x,item_y in dataset3:
    # 这里直接输出 item 的话,是tensor类型,想简单一点的话,像之前所讲,加个 numpy() 就好
    print(item_x.numpy(),item_y.numpy())

             我们再试试字典形式

dataset4 = tf.data.Dataset.from_tensor_slices({"fesature":x,
                                                "label":y})
# 直接输出 item 的字典数据
for item in dataset4:
    print(item)
# 简洁的输出方法
for item in dataset4:
    print(item["feature"].numpy(),item["label"].numpy())

 

二、Dataset 读取 CSV 文件

         2.1 tf.data.TextLineDataset , tf.io.decode_csv

             这里为了读者们的成功运行,挂上完整的前置代码

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import sklearn
import pandas as pd
import os
import sys
import time
import tensorflow as tf

from tensorflow import keras

print(tf.__version__)
print(sys.version_info)
for module in mpl, np, pd, sklearn, tf, keras:
    print(module.__name__, module.__version__)

from sklearn.datasets import fetch_california_housing

housing = fetch_california_housing()

from sklearn.model_selection import train_test_split

x_train_all, x_test, y_train_all, y_test = train_test_split(
    housing.data, housing.target, random_state = 7)
x_train, x_valid, y_train, y_valid = train_test_split(
    x_train_all, y_train_all, random_state = 11)
print(x_train.shape, y_train.shape)
print(x_valid.shape, y_valid.shape)
print(x_test.shape, y_test.shape)

from sklearn.preprocessing import StandardScaler

scaler = StandardScaler()
x_train_scaled = scaler.fit_transform(x_train)
x_valid_scaled = scaler.transform(x_valid)
x_test_scaled = scaler.transform(x_test)

             因为手里没有 csv 文件,我们先把标准化后的数据,输出成 csv 文件

# 输入一个文件夹名称,如果没有这个文件夹,则新建一个
output_dir  = "generate_csv"
if not os.path.exists(output_dir):
    os.mkdir(output_dir)

# 因为我们要生成三个文件,分别是 train,valid,test,设定一个输入参数为文件名,对于 csv 文件可能会有一个 header,分割成 10 个文件
def save_to_csv(output_dir,data,name_prefix,header=None,n_parts=10):

    # 第一个 {} 用于填写 train,valid,test,第二个用于存储每个past 的具体数字。只有一位数字的话,会在前面加一个数字
    path_format = os.path.join(output_dir,"{}_{:o2d}.csv")
    filenames = []

    # 新建一个长度与 data 一致的数组,并将它分成 n_parts 份,并给每一份配置一个序号,file_idx 为序号,row_indeices 为内容,也就是行号,与data里的数据的行数对应
    for file_idx,row_indices in enumerate(
    np.array_split(np.arrang(len(data)),n_parts)):

        # 对设定好的路径 path_format 进行实时排序,对 {} 分别插入 文件名与序号
        part_csv = path_format.format(name_prefix, file_idx)

        # 将编辑好的文件名计入到文件名数组内
        filenames.append(part_csv)

        # 以文本方式,写入模式,编码为 utf-8 打开文件 part_csv
        with open(part_csv, "wt", encoding="utf-8") as f:

            # 若有文件头,则在文件中插入该内容,并回车
            if header is not None:
                f.write(heade
  • 1
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值