python dataset模块_数据挖掘算法和实践(十三):使用tf.data.DataSet模块处理数据...

目录

类似于numpy中的ndarray数据类型和数据操作,TensorFlow提供了tf.data.DataSet模块,方便地处理数据输入、输出,支持大量的数据计算和转换,tf.data.DataSet中是一个或者多个tensor对象。

一、DataSet的创建:

直接从tensor创建tf.data.DataSet,使用tf.data.DataSet.from_tensor_slices()函数,函数参数可以是python自带数据类型list,或者numpy.ndarray:

# 可以从list,从numpy.ndarray创建 dataset

X= np.array([1.3,4.4,5.5,6.71,6.93,4.168,9.779,6.182,7.59,2.167,7.042,10.791,5.313,7.997,5.654,9.27,3.1])

Y= np.array([[1.3,4.4],[5.5,6.71]])

dataset1=tf.data.Dataset.from_tensor_slices([1,2,3,4]) # list 创建

dataset2=tf.data.Dataset.from_tensor_slices(X) # numpy 创建

dataset2

dataset的类型是tensorslicedataset,可以使用循环查看每个元素都是一个tensor,也可以用numpy方法;

for i in dataset2.take(2):

print(i)

print(i.numpy())

二、DataSet的常用函数:

1、在建模之前可以对数据进行处理,比如:

① shuffle()函数,提供乱序操作;

② repeat()函数,提供数据重复操作;

③ batch() 函数,提供批量读取功能;

X= np.array([1.3,4.4,5.5,6.71,6.93,4.168,9.779,6.182,7.59,2.167,7.042,10.791,5.313,7.997,5.654,9.27,3.1])

dataset2=tf.data.Dataset.from_tensor_slices(X) # numpy 创建

data_shuffle=dataset2.shuffle(3) # 打乱数据

data_repeat=dataset2.repeat(count=2) # 数据重复

data_batch=dataset2.batch(2) # 数据批量读取

2、数据变换,包括map函数

dataset_sq=dataset2.map(tf.square)

三、使用DataSet改写fashion_MNIST分类模型:

与之前处理方式不同在于,建模之前对数据进行了一些变换,并且增加了模型训练过程中的验证数据;

import matplotlib.pyplot as plt

import tensorflow as tf

import pandas as pd

import numpy as np

%matplotlib inline

(train_image,train_lable),(test_image,test_lable)=tf.keras.datasets.fashion_mnist.load_data()

plt.imshow(train_image[11]) # image show

ds_train_image=tf.data.Dataset.from_tensor_slices(train_image) # 加载数据

ds_train_lable=tf.data.Dataset.from_tensor_slices(train_lable) # 加载数据

# 打乱数据,无线重复,成批读取

da_train=tf.data.Dataset.zip((ds_train_image,ds_train_lable)).shuffle(10000).repeat().batch(64)

# 测试数据集

ds_test_image=tf.data.Dataset.from_tensor_slices(test_image) # 加载数据

ds_test_lable=tf.data.Dataset.from_tensor_slices(test_lable) # 加载数据

ds_test=tf.data.Dataset.zip((ds_test_image,ds_test_lable)).batch(64)

model=tf.keras.Sequential([

tf.keras.layers.Flatten(input_shape=(28,28)),

tf.keras.layers.Dense(128,activation="relu"),

tf.keras.layers.Dense(10,activation="softmax")

])

model.compile(optimizer="adam",loss="sparse_categorical_crossentropy",metrics=["acc"])

train_image.shape[0]//64

history=model.fit(da_train,

epochs=5,

steps_per_epoch=train_image.shape[0]//64,

validation_data=ds_test,

validation_steps=test_image.shape[0]//64

)

model.evaluate(test_image,test_lable)

plt.plot(history.epoch,history.history.get('loss'))

plt.plot(history.epoch,history.history.get('acc'))

原文链接:https://blog.csdn.net/yezonggang/article/details/106490552

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
这段代码是 Python 中的一些导入语句,它们用于导入一些常用的库和模块,以便在后续的代码中使用它们。具体来说: - os:提供了访问操作系统功能的接口,如文件系统、进程管理等。 - yaml:提供了读取和写入 YAML 格式文件的功能。 - pickle:提供了将 Python 对象序列化和反序列化的功能。 - shutil:提供了高级的文件操作功能,如复制、移动、删除等。 - tarfile:提供了访问 tar 文件的功能。 - glob:提供了匹配文件路径的功能。 - cv2:OpenCV 库,提供了计算机视觉相关的功能,如图像处理、计算机视觉算法等。 - albumentations:提供了数据增强相关的功能,如随机裁剪、旋转、缩放等。 - PIL:Python Imaging Library,提供了图像处理相关的功能,如图像缩放、旋转、裁剪等。 - numpy:提供了高性能的数值计算功能。 - torchvision.transforms.functional:提供了图像变换的功能,如旋转、裁剪、翻转等。 - OmegaConf:提供了配置文件的读取和解析功能。 - partial:提供了创建一个新函数的功能,该新函数是原函数的一个部分应用。 - Image:PIL 库中的一个类,用于表示图像。 - tqdm:提供了进度条功能,用于显示任务执行的进度。 - Dataset:PyTorch 中的一个抽象类,用于表示数据集。 - Subset:PyTorch 中的一个类,用于表示数据集的子集。 - taming.data.utils:taming data 包中的一个模块,提供了一些数据处理相关的函数。 - taming.data.imagenet:taming data 包中的一个模块,提供了 ImageNet 数据集的相关函数。 - str_to_indices:将 ImageNet 数据集中的类别名称转换为对应的类别索引。 - give_synsets_from_indices:根据 ImageNet 类别索引获取对应的 synset。 - download:下载 ImageNet 数据集。 - retrieve:从 ImageNet 数据集中提取图像。 - ImagePaths:表示 ImageNet 数据集中图像的路径。 - degradation_fn_bsr:图像降质函数,用于生成降质后的图像。 - degradation_fn_bsr_light:轻量级的图像降质函数。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值